
[Enhancement] Add async support with event-driven state transitions #35

@lanceyliao

Current Limitation

The current state_machine implementation only supports synchronous state transitions, which makes it difficult to handle:

  • Async operations in hooks (before/after callbacks)
  • Event-driven state transitions
  • Thread-safe state changes in concurrent environments

Proposed Enhancement

Add async support based on the following design:

1. Async Hooks & Decorators

@before_transition
async def risk_control(self, old_state, new_state) -> bool:
    # Async validation
    result = await self.check_risk()
    return result

@after_transition(state=OrderState.PAYING)
async def log_payment(self, old_state, new_state):
    await self.write_audit_log()
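One way to wire up decorators like these is to have each decorator only tag the coroutine function and let the base class collect the tagged methods. The sketch below uses `__init_subclass__` instead of a metaclass (a swapped-in technique, not what the prototype further down does), which also picks up hooks defined on intermediate base classes. `HookedMachine`, `_tag`, `_hook_spec` and `run_before` are hypothetical names chosen for illustration.

```python
import asyncio
import inspect
from typing import Callable, Dict, List, Optional, Tuple

Hook = Tuple[Optional[object], Callable]  # (state filter or None, coroutine function)


def _tag(kind: str, state: Optional[object] = None) -> Callable:
    """Hypothetical helper: mark an async method as a hook of the given kind."""
    def decorator(func: Callable) -> Callable:
        if not inspect.iscoroutinefunction(func):
            raise TypeError(f"{func.__name__} must be declared with 'async def'")
        func._hook_spec = (kind, state)
        return func  # no wrapper needed; the tag lives on the function itself
    return decorator


def before_transition(func: Callable) -> Callable:
    return _tag("before")(func)


def after_transition(state: object) -> Callable:
    return _tag("after", state)


class HookedMachine:
    _hooks: Dict[str, List[Hook]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        collected: Dict[str, Callable] = {}
        # Walk the MRO from the root down so subclass definitions override base ones
        for klass in reversed(cls.__mro__):
            for name, attr in vars(klass).items():
                if hasattr(attr, "_hook_spec"):
                    collected[name] = attr
                else:
                    collected.pop(name, None)  # overridden by a plain method
        cls._hooks = {}
        for func in collected.values():
            kind, state = func._hook_spec
            cls._hooks.setdefault(kind, []).append((state, func))

    async def run_before(self, old, new) -> bool:
        # Every 'before' hook must approve; the first False vetoes the transition
        for _state, hook in self._hooks.get("before", []):
            if not await hook(self, old, new):
                return False
        return True


class Order(HookedMachine):
    @before_transition
    async def risk_control(self, old, new) -> bool:
        # Illustrative rule: SHIPPED is only reachable from PAYING
        return new != "SHIPPED" or old == "PAYING"

    @after_transition("PAYING")
    async def log_payment(self, old, new) -> None:
        print(f"audit: {old} -> {new}")
```

For example, `asyncio.run(Order().run_before("CREATED", "SHIPPED"))` returns `False` because the only `before` hook vetoes it. Rejecting plain `def` functions at decoration time surfaces a missing `async` immediately instead of at the first transition.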

2. Event-Driven Transitions

  • Add async queue for state transition requests
  • Non-blocking transition requests
  • Ordered processing of concurrent transitions
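The queue-based flow above can be sketched with a single worker task draining an `asyncio.Queue`. In this sketch each request carries an `asyncio.Future`, so a caller can either fire and forget or await the outcome. `QueuedMachine`, `request` and the `allowed` table are hypothetical; veto hooks are reduced to a transition table to keep it short.

```python
import asyncio
from typing import Dict, List, Optional, Set, Tuple


class QueuedMachine:
    """Hypothetical sketch: one worker task applies queued transition requests in FIFO order."""

    def __init__(self, initial: str, allowed: Dict[str, Set[str]]) -> None:
        self.state = initial
        self._allowed = allowed  # stand-in for veto hooks: state -> reachable states
        self._queue: "asyncio.Queue[Tuple[str, asyncio.Future]]" = asyncio.Queue()
        self._worker: Optional["asyncio.Task[None]"] = None

    def start(self) -> None:
        # Must be called while an event loop is running
        self._worker = asyncio.get_running_loop().create_task(self._run())

    def request(self, new_state: str) -> "asyncio.Future[bool]":
        """Enqueue a transition and return immediately; the future resolves to True/False."""
        fut = asyncio.get_running_loop().create_future()
        self._queue.put_nowait((new_state, fut))
        return fut

    async def _run(self) -> None:
        while True:
            new_state, fut = await self._queue.get()
            old = self.state
            ok = new_state in self._allowed.get(old, set())
            if ok:
                self.state = new_state
            if not fut.cancelled():
                fut.set_result(ok)

    async def stop(self) -> None:
        if self._worker is not None:
            self._worker.cancel()
            try:
                await self._worker
            except asyncio.CancelledError:
                pass


async def demo() -> List[bool]:
    machine = QueuedMachine(
        "CREATED", {"CREATED": {"PAYING", "CANCELLED"}, "PAYING": {"SHIPPED"}}
    )
    machine.start()
    # Four requests submitted back to back; none of these calls blocks
    futures = [machine.request(s) for s in ("PAYING", "SHIPPED", "PAYING", "CANCELLED")]
    results = list(await asyncio.gather(*futures))  # the caller chooses to await the outcomes
    await machine.stop()
    return results
```

Because only one task ever mutates `state`, ordering comes from the queue itself. A callback and a future are not exclusive: `fut.add_done_callback(...)` gives callback-style notification on top of the same mechanism.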

3. Thread & Coroutine Safety

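  • Use asyncio.Lock to serialize state mutations between coroutines (asyncio.Lock itself is not thread-safe, so callers on other threads must hand their requests to the event loop)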
  • Queue-based state transition processing
  • Callback support for transition results
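Since asyncio synchronization primitives are not thread-safe, a sketch of how threads outside the loop could still request transitions safely: they submit the coroutine with `asyncio.run_coroutine_threadsafe`, which returns a `concurrent.futures.Future` they can block on or attach a result callback to. `LockedMachine` and its integer "state" are hypothetical simplifications; `asyncio.to_thread` assumes Python 3.9+.

```python
import asyncio
import threading
from typing import List, Tuple


class LockedMachine:
    """Hypothetical minimal machine whose state only changes on its event loop."""

    def __init__(self) -> None:
        self.state = 0
        self._lock = asyncio.Lock()  # coroutine-level mutual exclusion, not a thread lock

    async def transition(self, delta: int) -> int:
        async with self._lock:
            current = self.state
            await asyncio.sleep(0)  # a real hook awaiting I/O would yield here
            self.state = current + delta
            return self.state


async def main() -> Tuple[int, List[int]]:
    machine = LockedMachine()
    loop = asyncio.get_running_loop()
    results: List[int] = []

    def worker() -> None:
        # Runs in another OS thread: never mutate the machine directly,
        # hand the coroutine to the loop that owns the lock instead.
        for _ in range(50):
            fut = asyncio.run_coroutine_threadsafe(machine.transition(1), loop)
            fut.add_done_callback(lambda f: results.append(f.result()))  # result callback
            fut.result(timeout=5)  # block this thread until the loop has applied it

    threads = [threading.Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()
    # Join the threads off-loop so the event loop stays free to run the transitions
    await asyncio.gather(*(asyncio.to_thread(t.join) for t in threads))
    return machine.state, sorted(results)
```

Without the lock, the read/yield/write sequence in `transition` could interleave and lose updates; with it, the 200 requests from four threads produce exactly the states 1 through 200.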

Example Implementation

I've created a working prototype that demonstrates these features:
[Link to your async.py gist/repo]

Key features:

  1. Async hooks with veto power
  2. Event queue for transition requests
  3. Callback mechanism for transition results
  4. Coroutine-safe state mutations (transitions are serialized by an asyncio.Lock)

Benefits

  1. Better integration with async web frameworks (FastAPI, aiohttp)
  2. Natural handling of async operations (DB, API calls)
  3. Safe concurrent state transitions (requests are serialized instead of racing)
  4. Support for complex workflows (payment processing, etc.)

Breaking Changes

This would require:

  1. New async decorators (@before_transition_async, etc.)
  2. Modified base classes for async support
  3. Updated ORM adapters for async operations
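On the backward-compatibility question, one option that might avoid a parallel set of `*_async` decorators is a hook runner that accepts both plain and `async def` hooks, awaiting only when a hook returns an awaitable. `run_hooks` and the two example hooks are hypothetical, not part of any existing state_machine API.

```python
import asyncio
import inspect
from typing import Any, Callable, Iterable


async def run_hooks(hooks: Iterable[Callable[..., Any]], *args: Any) -> bool:
    """Hypothetical helper: run sync and async hooks in order; an explicit False vetoes."""
    for hook in hooks:
        result = hook(*args)
        if inspect.isawaitable(result):
            result = await result  # async hook: await it on the current loop
        if result is False:
            return False
    return True


def legacy_sync_hook(old: str, new: str) -> bool:
    # An existing synchronous hook, unchanged
    return new != "SHIPPED"


async def new_async_hook(old: str, new: str) -> None:
    await asyncio.sleep(0)  # e.g. write an audit log; returning None counts as approval
```

Treating only an explicit `False` as a veto (a deliberate difference from the prototype's `if not allow`) lets hooks that return nothing be reused unchanged. Note that a slow synchronous hook would still block the event loop while it runs.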

Questions

  1. Would you consider adding async support as a separate module?
  2. Any concerns about backward compatibility?
  3. Thoughts on the proposed API design?

Let me know if you'd like to see more details from the prototype implementation.

For reference, here is the prototype (generated by DeepSeek R1):

import asyncio
from enum import Enum, auto
from functools import wraps
from typing import Callable, Dict, List, Optional

# ----------------------------
# Core metaclass and decorator definitions
# ----------------------------

class AsyncStateMeta(type):
    """Metaclass: collects the async hooks tagged by the decorators below."""
    def __new__(cls, name, bases, attrs):
        # Per-class hook registries
        attrs["_before_hooks"]: List[Callable] = []
        attrs["_after_hooks"]: Dict[Enum, List[Callable]] = {}
        attrs["_enter_hooks"]: Dict[Enum, List[Callable]] = {}
        attrs["_exit_hooks"]: Dict[Enum, List[Callable]] = {}
        # The request queue is created per instance in __init__: a class-level
        # asyncio.Queue would be shared by every instance of the class.

        # Collect methods tagged by the decorators. Only methods defined directly
        # in this class body are collected; hooks on base classes are not inherited.
        for method_name, method in attrs.items():
            if hasattr(method, "_hook_type"):
                hook_type = method._hook_type
                if hook_type == "before":
                    attrs["_before_hooks"].append(method)
                elif hook_type == "after":
                    state = method._state
                    attrs["_after_hooks"].setdefault(state, []).append(method)
                elif hook_type == "enter":
                    state = method._state
                    attrs["_enter_hooks"].setdefault(state, []).append(method)
                elif hook_type == "exit":
                    state = method._state
                    attrs["_exit_hooks"].setdefault(state, []).append(method)
        return super().__new__(cls, name, bases, attrs)

# Note: @wraps copies func.__dict__, so the tags set on func also land on the wrapper.

def before_transition(func: Callable) -> Callable:
    """Global async before hook (may veto the transition)."""
    func._hook_type = "before"
    @wraps(func)
    async def wrapper(*args, **kwargs):
        return await func(*args, **kwargs)
    return wrapper

def after_transition(state: Enum) -> Callable:
    """Decorator factory for a state-specific after hook (runs only after entering `state`)."""
    def decorator(func: Callable) -> Callable:
        func._hook_type = "after"
        func._state = state
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
        return wrapper
    return decorator

def on_enter(state: Enum) -> Callable:
    """Async hook run when entering `state` (may veto the transition)."""
    def decorator(func: Callable) -> Callable:
        func._hook_type = "enter"
        func._state = state
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
        return wrapper
    return decorator

def on_exit(state: Enum) -> Callable:
    """Async hook run when leaving `state` (may veto the transition)."""
    def decorator(func: Callable) -> Callable:
        func._hook_type = "exit"
        func._state = state
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
        return wrapper
    return decorator

# ----------------------------
# Async state machine base class
# ----------------------------

class AsyncStateMachine(metaclass=AsyncStateMeta):
    def __init__(self, initial_state: Enum):
        # Must be constructed while an event loop is running (e.g. inside an async function)
        self._state = initial_state
        self._lock = asyncio.Lock()  # serializes transitions across coroutines (not threads)
        self._pending_requests: asyncio.Queue = asyncio.Queue()  # one queue per instance
        # Keep a reference so the worker task is not garbage-collected
        self._worker = asyncio.get_running_loop().create_task(self._process_requests())

    @property
    async def state(self) -> Enum:
        """Current state; read it with `await machine.state`."""
        async with self._lock:
            return self._state

    async def request_transition(self, new_state: Enum, callback: Optional[Callable] = None):
        """Submit a transition request without waiting for it to be processed."""
        await self._pending_requests.put((new_state, callback))

    async def _process_requests(self):
        """Background task: process transition requests one at a time, in FIFO order."""
        while True:
            new_state, callback = await self._pending_requests.get()
            success = False
            async with self._lock:
                old_state = self._state
                try:
                    success = await self._execute_transition(old_state, new_state)
                except Exception as exc:
                    # A failing hook counts as a veto instead of killing the worker task
                    print(f"Transition {old_state.name} -> {new_state.name} failed: {exc!r}")
            self._pending_requests.task_done()  # lets callers await queue.join()
            if callback:
                callback(success, old_state, new_state)

    async def _execute_transition(self, old_state: Enum, new_state: Enum) -> bool:
        """Run one transition; the caller holds self._lock, so it is atomic w.r.t. other requests."""
        # Note: side effects of hooks that already ran are not rolled back if a later hook vetoes.
        # ----------------------------
        # Phase 1: global before hooks
        # ----------------------------
        for hook in self._before_hooks:
            allow = await hook(self, old_state, new_state)
            if not allow:  # any before hook can veto the transition
                return False

        # ----------------------------
        # Phase 2: exit hooks of the old state
        # ----------------------------
        for hook in self._exit_hooks.get(old_state, []):
            allow = await hook(self, old_state, new_state)
            if not allow:
                return False

        # ----------------------------
        # Phase 3: enter hooks of the new state
        # ----------------------------
        for hook in self._enter_hooks.get(new_state, []):
            allow = await hook(self, old_state, new_state)
            if not allow:
                return False

        # ----------------------------
        # Commit the state change
        # ----------------------------
        self._state = new_state

        # ----------------------------
        # Phase 4: after hooks (only reached when the transition succeeded)
        # ----------------------------
        for hook in self._after_hooks.get(new_state, []):
            await hook(self, old_state, new_state)

        return True

# ----------------------------
# Usage example: an order system
# ----------------------------

class OrderState(Enum):
    CREATED = auto()
    PAYING = auto()
    SHIPPED = auto()
    CANCELLED = auto()

class OrderSystem(AsyncStateMachine):
    def __init__(self):
        super().__init__(initial_state=OrderState.CREATED)

    # ------------
    # Global before hook: risk-control check
    # ------------
    @before_transition
    async def risk_control(self, old: OrderState, new: OrderState) -> bool:
        print(f"[Risk control] Checking {old.name} -> {new.name}")
        if new == OrderState.PAYING and old != OrderState.CREATED:
            print("Blocked by risk control: invalid payment path")
            return False
        return True

    # ------------
    # Enter hook for PAYING: call the payment gateway
    # ------------
    @on_enter(OrderState.PAYING)
    async def start_payment(self, old: OrderState, new: OrderState) -> bool:
        print("Calling payment API...")
        # Simulated async payment result
        payment_success = await self._mock_payment_api()
        if not payment_success:
            print("Payment failed, cancelling the transition")
            return False  # veto the transition
        return True

    # ------------
    # After hook for PAYING: write an audit log entry
    # ------------
    @after_transition(OrderState.PAYING)
    async def log_payment_success(self, old: OrderState, new: OrderState):
        print("Payment succeeded, writing audit log")

    async def _mock_payment_api(self) -> bool:
        await asyncio.sleep(1)  # simulated network latency
        return True  # change to False to test the failure path

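# ----------------------------
# Test: queued requests and result callbacks
# ----------------------------
async def main():
    order = OrderSystem()

    # Callback invoked after each request is processed (a plain synchronous function)
    def on_transition_done(success: bool, old: OrderState, new: OrderState):
        print(f"\nCallback: transition {old.name} -> {new.name} {'succeeded' if success else 'failed'}")

    # Submit several requests back to back (they are processed one at a time, in order)
    await order.request_transition(OrderState.PAYING, on_transition_done)     # CREATED -> PAYING: allowed
    await order.request_transition(OrderState.SHIPPED, on_transition_done)    # PAYING -> SHIPPED: allowed (risk control only guards PAYING)
    await order.request_transition(OrderState.PAYING, on_transition_done)     # SHIPPED -> PAYING: blocked by risk control
    await order.request_transition(OrderState.CANCELLED, on_transition_done)  # SHIPPED -> CANCELLED: allowed (no rule forbids it)

    # Wait for the queue to be processed (crude: the mock payment alone takes about 1 s)
    await asyncio.sleep(3)

if __name__ == "__main__":
    asyncio.run(main())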
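The fixed `await asyncio.sleep(3)` in `main()` only works because the mock payment takes about a second. A sturdier pattern, sketched here with a bare queue rather than the prototype's classes, is for the worker to call `Queue.task_done()` once per request and for the caller to `await queue.join()`. The worker body is a stand-in for running the transition hooks.

```python
import asyncio
from typing import List


async def main() -> List[str]:
    queue: "asyncio.Queue[str]" = asyncio.Queue()
    processed: List[str] = []

    async def worker() -> None:
        while True:
            item = await queue.get()
            try:
                await asyncio.sleep(0.01)  # stand-in for running the transition hooks
                processed.append(item)
            finally:
                queue.task_done()  # tells queue.join() this request is finished

    task = asyncio.create_task(worker())
    for state in ("PAYING", "SHIPPED", "PAYING", "CANCELLED"):
        queue.put_nowait(state)
    await queue.join()  # returns as soon as all four requests are handled
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return processed
```

`join()` returns when the count of unfinished items reaches zero, so the wait adapts to however long the hooks actually take.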
