diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index c036b076..ece240a2 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -9,7 +9,7 @@ jobs: run: shell: bash container: - image: silkeh/clang:18 + image: silkeh/clang:19 options: --user root timeout-minutes: 10 steps: @@ -23,13 +23,33 @@ jobs: cmake --build build --target lin_check_test - name: Run lin check test run: ctest --test-dir build -R "^LinearizabilityCheckerCounterTest" -V + # pass-tests: + # runs-on: ubuntu-latest + # defaults: + # run: + # shell: bash + # container: + # image: silkeh/clang:19 + # options: --user root + # timeout-minutes: 10 + # steps: + # - name: Install deps + # run: apt update && apt install -y git ninja-build valgrind libboost-context-dev libgflags-dev + # - name: Check out repository code + # uses: actions/checkout@v4 + # - name: Build + # run: | + # cmake -G Ninja -B build -DCMAKE_BUILD_TYPE=RelWithAssert + # cmake --build build --target CoYieldPass + # - name: Run lin check test + # run: ctest --test-dir build -L llvm-pass -V verifying-test: runs-on: ubuntu-latest defaults: run: shell: bash container: - image: silkeh/clang:18 + image: silkeh/clang:19 options: --user root timeout-minutes: 10 steps: diff --git a/Dockerfile b/Dockerfile index 377dbfd1..c52b4fb0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ FROM silkeh/clang:19 AS ltest RUN apt update && apt install -y git ninja-build valgrind libboost-context-dev libgflags-dev libstdc++-11-dev RUN mv /usr/lib/gcc/x86_64-linux-gnu/12 /usr/lib/gcc/x86_64-linux-gnu/_12 -FROM ltest as blocking +FROM ltest AS blocking RUN apt install -y pkg-config libcapstone-dev && \ git clone https://github.com/Kirillog/syscall_intercept.git && \ cmake syscall_intercept -G Ninja -B syscall_intercept/build -DCMAKE_INSTALL_PREFIX=/usr -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_COMPILER=clang && \ diff --git a/codegen/coyieldpass.cpp b/codegen/coyieldpass.cpp index d97589aa..fd217551 100644 --- a/codegen/coyieldpass.cpp +++ b/codegen/coyieldpass.cpp @@ -1,3 +1,10 @@ + +#include +#include +#include +#include +#include +#include #include #include #include @@ -6,6 +13,8 @@ #include #include #include +#include +#include #include #include #include @@ -13,9 +22,11 @@ #include #include #include +#include #include -#include +#include +#include #include #include #include @@ -32,83 +43,216 @@ using namespace llvm; using Builder = IRBuilder<>; constexpr std::string_view costatus_change = "CoroutineStatusChange"; +constexpr std::string_view create_thread = "CreateNewVirtualThread"; +constexpr std::string_view wait_thread = "WaitForThread"; +constexpr std::string_view spawned_coro_start = "VirtualThreadStartPoint"; -constexpr std::string_view co_expr_start = "::await_ready"; - -constexpr std::string_view co_expr_end = "::await_resume"; -constexpr std::string_view co_initial_suspend = "::initial_suspend()"; -constexpr std::string_view co_final_suspend = "::final_suspend()"; - -constexpr std::string_view no_filter = "any"; - +constexpr std::string_view co_await_ready = "await_ready"; +constexpr std::string_view co_initial_suspend = "initial_suspend()"; +constexpr int resumed_coro = 0; static cl::opt input_list( - "coroutine-file", cl::desc("Specify path to file with coroutines to check"), + "coroutine-file", cl::desc("Specify path to file with config"), llvm::cl::Required); ; -struct CoroutineFilter { - CoroutineFilter() = default; - CoroutineFilter(const std::optional &parent_name, - const std::optional &co_name, - const std::string &print_name) - : parent_name(parent_name), co_name(co_name), print_name(print_name) {}; - std::optional parent_name; + +constexpr bool dump_before = false; +constexpr bool dump_after = false; + +enum HandleType { + GENERIC_FUN, + CORO_FUN, + SPAWN_VIRT_THREAD, + SPAWN_CORO_INFO, + WAIT_VIRT_THREAD +}; + +struct InsertPlace { + InsertPlace() = default; std::optional co_name; + std::optional parent_name; + // i believe that adding column is overkill + std::optional debug_line; + std::optional debug_file; +}; + +struct InsertAction { + InsertPlace place; + HandleType type; + virtual HandleType GetType() = 0; + virtual ~InsertAction() {} +}; + +struct InsertActionWithName : InsertAction { std::string print_name; + InsertActionWithName(const std::string &print_name) + : print_name(print_name) {} + virtual ~InsertActionWithName() {} +}; + +struct InsertCoro : InsertActionWithName { + using InsertActionWithName::InsertActionWithName; + HandleType GetType() override { return HandleType::CORO_FUN; } }; +struct InsertGeneric : InsertActionWithName { + using InsertActionWithName::InsertActionWithName; + HandleType GetType() override { return HandleType::GENERIC_FUN; } +}; + +struct InsertSpawn : InsertAction { + int creation_id; + bool has_this; + HandleType GetType() override { return HandleType::SPAWN_VIRT_THREAD; } + InsertSpawn(int creation_id, bool has_this) + : creation_id(creation_id), has_this(has_this) {} +}; + +struct InsertSpawnedCoro : InsertActionWithName { + std::optional args_fun_name; + HandleType GetType() override { return HandleType::SPAWN_CORO_INFO; } + InsertSpawnedCoro(const std::optional &args_fun, + const std::string &print_name) + : args_fun_name(args_fun), InsertActionWithName(print_name) {} +}; + +struct InsertWaitThread : InsertAction { + std::vector wait_for_ids; + HandleType GetType() override { return HandleType::WAIT_VIRT_THREAD; } + InsertWaitThread(const std::vector &wait) : wait_for_ids(wait) {} +}; + +using InsertActionPtr = std::shared_ptr; + namespace llvm { namespace yaml { +static std::map> + construct_action{{"Generic", + [](IO &io) { + std::string print_name; + io.mapRequired("Name", print_name); + return std::make_shared(print_name); + }}, + {"Coro", + [](IO &io) { + std::string print_name; + io.mapRequired("Name", print_name); + return std::make_shared(print_name); + }}, + {"Spawn", + [](IO &io) { + int creation_id; + bool has_this; + io.mapRequired("CreationId", creation_id); + io.mapRequired("HasThis", has_this); + + return std::make_shared(creation_id, + has_this); + }}, + {"SpawnedCoro", + [](IO &io) { + std::optional args_fun_name; + std::string print_name; + io.mapOptional("ArgsFun", args_fun_name); + io.mapRequired("Name", print_name); + return std::make_shared( + args_fun_name, print_name); + }}, + {"Wait", [](IO &io) { + std::vector vect; + io.mapRequired("WaitsFor", vect); + assert(vect.size() > 0); + return std::make_shared(vect); + }}}; +template <> +struct MappingTraits { + static void mapping(IO &io, + InsertActionPtr &action) { // NOLINT + std::string type; + // theoreticaly it should work with tags, but it don't work( + io.mapRequired("Action", type); + auto entry = construct_action.find(type); + if (entry == construct_action.end()) { + io.setError("Got unexpected action - " + type); + return; + } + action = (entry->second)(io); + io.mapRequired("Place", action->place); + } +}; + template <> -struct MappingTraits { - static void mapping(IO &io, CoroutineFilter &cofilter) { // NOLINT - io.mapRequired("Name", cofilter.print_name); - io.mapOptional("Coroutine", cofilter.co_name); +struct MappingTraits { + static void mapping(IO &io, InsertPlace &cofilter) { // NOLINT + io.mapOptional("Function", cofilter.co_name); io.mapOptional("Parent", cofilter.parent_name); + + io.mapOptional("Line", cofilter.debug_line); + io.mapOptional("File", cofilter.debug_file); } }; } // namespace yaml } // namespace llvm -LLVM_YAML_IS_SEQUENCE_VECTOR(CoroutineFilter); +LLVM_YAML_IS_SEQUENCE_VECTOR(InsertActionPtr); struct CoYieldInserter { - CoYieldInserter(Module &m, std::vector &&co_filter) - : m(m), co_filter(std::move(co_filter)) { + Builder builder; + CoYieldInserter(Module &m, std::vector &&co_filter) + : m(m), co_filter(std::move(co_filter)), builder(m.getContext()) { auto &context = m.getContext(); - coroYieldF = m.getOrInsertFunction( + auto *char_ptr = PointerType::get(Type::getInt8Ty(context), 0); + coro_yield_f = m.getOrInsertFunction( costatus_change, FunctionType::get(Type::getVoidTy(context), - {PointerType::get(Type::getInt8Ty(context), 0), - Type::getInt8Ty(context)}, - {})); + {char_ptr, Type::getInt1Ty(context)}, {})); + create_thread_f = m.getOrInsertFunction( + create_thread, + FunctionType::get( + Type::getVoidTy(context), + {Type::getInt32Ty(context), PointerType::get(context, 0)}, {})); + wait_thread_f = m.getOrInsertFunction( + wait_thread, + FunctionType::get( + Type::getVoidTy(context), + {PointerType::get(context, 0), Type::getInt32Ty(context)}, {})); + spawned_coro_start_f = m.getOrInsertFunction( + spawned_coro_start, + FunctionType::get(Type::getVoidTy(context), {char_ptr, char_ptr}, {})); } - void Run(const Module &index) { + void Run() { + if (dump_before) { + m.dump(); + errs().flush(); + } for (auto &f : m) { + if (ignored.contains(&f)) { + continue; + } std::string demangled = demangle(f.getName()); auto filt = co_filter | std::ranges::views::filter( - [&demangled](const CoroutineFilter &a) -> bool { - return !a.parent_name || a.parent_name == demangled; + [&demangled](const InsertActionPtr &a) -> bool { + const auto &place = a->place; + return !place.parent_name || + place.parent_name == demangled; }); if (!filt.empty()) { - InsertYields(filt, f); + InsertContextSwitchFunctions(filt, f); } } + for (auto inst : to_delete) { + inst->eraseFromParent(); + } + if (dump_after) { + m.dump(); + errs().flush(); + } } private: - void InsertYields(auto filt, Function &f) { + void InsertContextSwitchFunctions(auto filt, Function &f) { Builder builder(&*f.begin()); - /* - In fact co_await expr when expr is coroutine is - co_await initial_suspend() - coro body... - co_await_final_suspend() - We are interested to insert only before initial_suspend and - after final_suspend - */ - int skip_insert_points = 0; for (auto &b : f) { for (auto &i : b) { CallBase *call = dyn_cast(&i); @@ -121,108 +265,464 @@ struct CoYieldInserter { } auto raw_fn_name = c_fn->getName(); std::string co_name = demangle(raw_fn_name); - bool is_call_inst = isa(call); + if (co_name == costatus_change) { + continue; + } + CallInst *call_inst = dyn_cast(call); InvokeInst *invoke = dyn_cast(call); - if (is_call_inst || invoke) { - auto res_filt = - filt | std::ranges::views::filter( - [&co_name](const CoroutineFilter &a) -> bool { - return !a.co_name || a.co_name == co_name; - }); - if (!res_filt.empty()) { - auto filt_entry = res_filt.front(); - errs() << "inserted " << filt_entry.print_name << "\n"; - builder.SetInsertPoint(call); - InsertCall(filt_entry, builder, true); - // Invoke instruction has unwind/normal ends so we need handle it - if (invoke) { - builder.SetInsertPoint(invoke->getNormalDest()->getFirstInsertionPt()); - InsertCall(filt_entry, builder, false); - builder.SetInsertPoint(invoke->getUnwindDest()->getFirstInsertionPt()); - InsertCall(filt_entry, builder, false); - } else { - builder.SetInsertPoint(call->getNextNode()); - InsertCall(filt_entry, builder, false); - } - continue; + + if (call_inst || invoke) { + // filt and filt | filter have different types + if (auto debugLoc = call->getDebugLoc()) { + auto place_filt = + filt | + std::ranges::views::filter( + [&co_name, &debugLoc](const InsertActionPtr &a) -> bool { + const auto &place = a->place; + if (place.debug_file.has_value() && + place.debug_file != + debugLoc->getFile()->getFilename()) { + return false; + } + if (place.debug_line.has_value() && + place.debug_line != debugLoc.getLine()) { + return false; + } + return true; + }); + InsertContextSwitchFun(place_filt, call_inst, invoke, co_name); + + } else { + InsertContextSwitchFun(filt, call_inst, invoke, co_name); } } - if (!is_call_inst) { - continue; + } + } + } + + void InsertContextSwitchFun(auto filt, CallInst *call, InvokeInst *invoke, + std::string co_name) { + auto await_ready_ind = co_name.find(co_await_ready); + if (await_ready_ind != std::string::npos) { + auto res_filt = + filt | std::ranges::views::filter( + [&co_name](const InsertActionPtr &a) -> bool { + const auto &place = a->place; + return !place.co_name || + co_name.find(*place.co_name) != std::string::npos; + }); + if (!res_filt.empty()) { + auto entry = res_filt.front(); + switch (entry->GetType()) { + case HandleType::CORO_FUN: { + errs() << "inserted coro handled by type " << co_name << "\n"; + auto act_name = + std::static_pointer_cast(entry); + auto start = [this, &act_name]() { + InsertYieldCall(act_name, true); + }; + auto end = [this, &act_name]() { + InsertYieldCall(act_name, false); + }; + HandleCoroCase(call, res_filt.front(), start, end); + return; + } + default: + return; } - auto initial = co_name.find(co_initial_suspend); - if (initial != std::string::npos) { - builder.SetInsertPoint(call); - InsertCallWithFilter(filt, co_name, builder, true, initial); - skip_insert_points = 2; - continue; + } + return; + } + auto res_filt = + filt | std::ranges::views::filter( + [&co_name](const InsertActionPtr &a) -> bool { + const auto &place = a->place; + return !place.co_name || place.co_name == co_name; + }); + + if (!res_filt.empty()) { + auto entry = res_filt.front(); + switch (entry->GetType()) { + case HandleType::CORO_FUN: { + errs() << "inserted coro handled by func name " << co_name << "\n"; + auto act_name = std::static_pointer_cast(entry); + auto start = [this, &act_name]() { InsertYieldCall(act_name, true); }; + auto end = [this, &act_name]() { InsertYieldCall(act_name, false); }; + if (invoke) { + assert(FindAwaitReady(invoke->getNormalDest()->begin(), entry, + start, end)); + } else { + assert(FindAwaitReady(BasicBlock::iterator(call->getNextNode()), + entry, start, end)); + } + break; } - auto final = co_name.find(co_final_suspend); - if (final != std::string::npos) { - builder.SetInsertPoint(call->getNextNode()); - InsertCallWithFilter(filt, co_name, builder, false, final); - skip_insert_points = 2; - continue; + case HandleType::GENERIC_FUN: { + errs() << "inserted generic " << co_name << "\n"; + HandleGenericFunCase( + call, invoke, + std::static_pointer_cast(entry)); + break; } - auto start = co_name.find(co_expr_start); - if (start != std::string::npos) { - if (skip_insert_points != 0) { - assert(skip_insert_points == 2); - skip_insert_points--; - continue; + case HandleType::SPAWN_VIRT_THREAD: { + errs() << "inserted spawn of new thread " << co_name << "\n"; + CallBase *inst = call ? static_cast(call) : invoke; + Function *called_fun = inst->getCalledFunction(); + + assert(!inst->arg_empty()); + auto [wrapper_fun, storage] = InsertZeroArgsWrapper(inst); + builder.SetInsertPoint(inst); + for (size_t i = 0; i < called_fun->arg_size(); i++) { + Value *arg = inst->getArgOperand(i); + + Value *storage_place = + builder.CreateGEP(storage->getValueType(), storage, + { + builder.getInt32(0), + builder.getInt32(i), + }); + builder.CreateStore(arg, storage_place); } - builder.SetInsertPoint(call); - InsertCallWithFilter(filt, co_name, builder, true, start); - continue; + Value *pointer_to_func = builder.CreatePointerCast( + wrapper_fun, PointerType::get(builder.getContext(), 0)); + std::shared_ptr spawn = + std::static_pointer_cast(entry); + Value *id = builder.getInt32(spawn->creation_id); + builder.SetInsertPoint(inst); + Value *replacement; + if (call) { + replacement = + builder.CreateCall(create_thread_f, {id, pointer_to_func}); + } else { + replacement = builder.CreateInvoke( + create_thread_f, invoke->getNormalDest(), + invoke->getUnwindDest(), {id, pointer_to_func}); + } + inst->replaceAllUsesWith(replacement); + // we cannot simple delete here instruction because we are + // iterating over it in basic block + to_delete.push_back(inst); + break; } - auto end_pos = co_name.find(co_expr_end); - if (end_pos != std::string::npos) { - if (skip_insert_points != 0) { - assert(skip_insert_points == 1); - skip_insert_points--; - continue; + case HandleType::SPAWN_CORO_INFO: { + CallBase *inst = call ? static_cast(call) : invoke; + errs() << "inserted coroutine info " << co_name << "\n"; + auto spawn_coro = std::static_pointer_cast(entry); + Function *func = nullptr; + if (spawn_coro->args_fun_name) { + func = m.getFunction(*spawn_coro->args_fun_name); + assert(func); } - builder.SetInsertPoint(call->getNextNode()); - InsertCallWithFilter(filt, co_name, builder, false, end_pos); - continue; + InsertAtBodyStart(inst, func, [this, &spawn_coro](Value *args) { + builder.CreateCall(spawned_coro_start_f, + {GetLiteral(spawn_coro->print_name), + GetLiteral(spawn_coro->print_name)}); + builder.CreateAnd(ConstantInt::get(builder.getInt16Ty(), 1), + ConstantInt::get(builder.getInt16Ty(), 1)); + }); + break; } + case HandleType::WAIT_VIRT_THREAD: { + errs() << "inserted wait of new thread " << co_name << "\n"; + // I'm sure that here must be only coro case + auto start = std::bind_front( + &CoYieldInserter::InsertWaitFunc, this, + std::static_pointer_cast(entry)); + if (invoke) { + assert(InsertBeforeAnyCoroCall(invoke->getNormalDest()->begin(), + start)); + } else { + assert(InsertBeforeAnyCoroCall( + BasicBlock::iterator(call->getNextNode()), start)); + } + break; + } + default: + __builtin_unreachable(); } } } + void InsertWaitFunc(const std::shared_ptr &act) { + // because in c++ this function will have pointer as argument, we need also + // explicitly pass the size + std::vector elements; + for (auto &a : act->wait_for_ids) { + elements.push_back(builder.getInt32(a)); + } + Constant *indices = ConstantArray::get( + ArrayType::get(builder.getInt32Ty(), act->wait_for_ids.size()), + elements); + GlobalVariable *val_ind = + new GlobalVariable(m, indices->getType(), true, + llvm::GlobalValue::InternalLinkage, indices); + builder.CreateCall(wait_thread_f, + {val_ind, builder.getInt32(elements.size())}); + } - void InsertCallWithFilter(auto filt, StringRef co_name, Builder &builder, - bool start, int end_pos) { - auto res_filt = - filt | std::ranges::views::filter( - [&end_pos, &co_name](const CoroutineFilter &a) -> bool { - return !a.co_name || - a.co_name == co_name.substr(0, end_pos); - }); - if (res_filt.empty()) { - return; + // We need pass to scheduler function and wan't to care about number + // of args and their type to not interact with templates - so lets create a + // wrapper which would have zero args + std::pair InsertZeroArgsWrapper( + CallBase *call_inst) { + Function *func = + Function::Create(FunctionType::get(Type::getVoidTy(m.getContext()), {}), + GlobalValue::PrivateLinkage, "", m); + ignored.insert(func); + std::vector types; + for (auto &arg : call_inst->args()) { + types.push_back(arg->getType()); } - errs() << "inserted " << co_name.str() << "\n"; - // First in the config will match - InsertCall(res_filt.front(), builder, start); + StructType *storage_type = StructType::create(types); + GlobalVariable *storage = + new GlobalVariable(m, storage_type, false, GlobalValue::PrivateLinkage, + Constant::getNullValue(storage_type)); + BasicBlock *block = BasicBlock::Create(builder.getContext(), "", func); + builder.SetInsertPoint(block); + std::vector args; + for (size_t i = 0; i < types.size(); i++) { + Value *load = builder.CreateGEP( + storage_type, storage, {builder.getInt32(0), builder.getInt32(i)}); + args.push_back(builder.CreateLoad(types[i], load)); + } + builder.CreateCall(call_inst->getCalledFunction(), {args}); + builder.CreateRetVoid(); + return {func, storage}; } - void InsertCall(const CoroutineFilter &filt, Builder &builder, bool start) { + bool FindAwaitReady(BasicBlock::iterator start, InsertActionPtr &entry, + const std::function &start_insert, + const std::function &end_insert) { + for (Instruction &n_inst : make_range(start, start->getParent()->end())) { + auto *call_inst = dyn_cast(&n_inst); + if (!call_inst) { + continue; + } + auto await_ready_ind = demangle(call_inst->getCalledFunction()->getName()) + .find(co_await_ready); + if (await_ready_ind != std::string::npos) { + HandleCoroCase(call_inst, entry, start_insert, end_insert); + return true; + } + // If Coro Type constructor can throw we need go deeper + if (auto *invoke = dyn_cast(call_inst)) { + return FindAwaitReady(invoke->getNormalDest()->begin(), entry, + start_insert, end_insert); + } + } + return false; + } + + bool InsertBeforeAnyCoroCall(BasicBlock::iterator start, + const std::function &insert) { + for (Instruction &n_inst : make_range(start, start->getParent()->end())) { + auto *call_inst = dyn_cast(&n_inst); + if (!call_inst) { + continue; + } + auto await_ready_ind = demangle(call_inst->getCalledFunction()->getName()) + .find(co_await_ready); + if (await_ready_ind != std::string::npos) { + builder.SetInsertPoint(call_inst); + insert(); + return true; + } + } + return false; + } + // This case is needed at sample by some coro primitives where the + // normal function which is the body of coro is called in loop + void HandleGenericFunCase( + CallBase *call, InvokeInst *invoke, + const std::shared_ptr &filt_entry) { + builder.SetInsertPoint(call); + InsertYieldCall(filt_entry, true); + // Invoke instruction has unwind/normal ends so we need handle it + if (invoke) { + builder.SetInsertPoint(invoke->getNormalDest()->getFirstInsertionPt()); + InsertYieldCall(filt_entry, false); + builder.SetInsertPoint(invoke->getUnwindDest()->getFirstInsertionPt()); + InsertYieldCall(filt_entry, false); + } else { + builder.SetInsertPoint(call->getNextNode()); + InsertYieldCall(filt_entry, false); + } + } + + void HandleCoroCase(CallBase *call, const InsertActionPtr &filt_entry, + const std::function &start_insert, + const std::function &end_insert) { + BranchInst *br = dyn_cast(call->getNextNode()); + assert(br && br->getNumSuccessors() == 2); + BasicBlock *not_ready_bb = br->getSuccessor(1); + for (auto &i : *not_ready_bb) { + CallBase *call_base = dyn_cast(&i); + if (!call_base) { + continue; + } + + Intrinsic::ID id = call_base->getIntrinsicID(); + switch (id) { + // We cannot insert after await_suspend because inside it we can + // already interact with handle, so we must we do it before + case Intrinsic::coro_await_suspend_bool: { + if (start_insert) { + builder.SetInsertPoint(call_base); + start_insert(); + } + if (end_insert) { + BranchInst *suspend_br = dyn_cast(i.getNextNode()); + assert(suspend_br && suspend_br->getNumSuccessors() == 2); + builder.SetInsertPoint( + suspend_br->getSuccessor(0)->getFirstInsertionPt()); + // handled if await_suspend was true, now change block also for + // false + BasicBlock *tramp = InsertAtEnd( + builder, &(*builder.GetInsertPoint()), filt_entry, end_insert); + suspend_br->setSuccessor(1, tramp); + } + return; + } + case Intrinsic::coro_await_suspend_void: { + if (start_insert) { + builder.SetInsertPoint(call_base); + start_insert(); + } + if (end_insert) { + InsertAtEnd(builder, nextNormal(call_base), filt_entry, end_insert); + } + return; + } + case Intrinsic::coro_await_suspend_handle: { + if (start_insert) { + builder.SetInsertPoint(call_base); + start_insert(); + } + if (end_insert) { + InsertAtEnd(builder, nextNormal(call_base), filt_entry, end_insert); + } + return; + } + default: { + continue; + } + } + } + assert(false && "Haven't found await_suspend intrisinc"); + } + + Instruction *nextNormal(CallBase *inst) { + if (auto invoke = dyn_cast(inst)) { + return invoke->getNormalDest()->getFirstNonPHI(); + } else { + return inst->getNextNode(); + } + } + BasicBlock *InsertAtEnd(Builder &builder, Instruction *instr, + const InsertActionPtr &filt_entry, + const std::function &end_insert) { + CallInst *intr = dyn_cast(instr); + assert(intr && intr->getIntrinsicID() == Intrinsic::coro_suspend); + SwitchInst *switch_inst = dyn_cast(intr->getNextNode()); + assert(switch_inst); + auto resumed_bb = switch_inst->findCaseValue( + ConstantInt::get(Type::getInt8Ty(builder.getContext()), resumed_coro)); + auto succ = resumed_bb->getCaseSuccessor(); + // If we would simple insert in the block we would have extra ends, so we + // need to add a trampoline + BasicBlock *tramp = + BasicBlock::Create(builder.getContext(), "", succ->getParent()); + resumed_bb->setSuccessor(tramp); + builder.SetInsertPoint(tramp); + end_insert(); + builder.CreateBr(succ); + return tramp; + } + void InsertYieldCall(const std::shared_ptr &filt, + bool start) { auto llvm_start = ConstantInt::get(Type::getInt1Ty(builder.getContext()), start); - Constant *str_const = - ConstantDataArray::getString(m.getContext(), filt.print_name, true); - auto zero = ConstantInt::get(Type::getInt32Ty(m.getContext()), 0); - Constant *ind[] = {zero, zero}; - GlobalVariable *global = new GlobalVariable( - m, str_const->getType(), true, GlobalValue::PrivateLinkage, str_const); - auto ptr = - ConstantExpr::getGetElementPtr(global->getValueType(), global, ind); - builder.CreateCall(coroYieldF, {ptr, llvm_start}); + builder.CreateCall(coro_yield_f, + {GetLiteral(filt->print_name), llvm_start}); } + void InsertAtBodyStart(CallBase *call, Function *insert_at_start_fun, + const std::function &insert) { + auto f = call->getCalledFunction(); + ValueToValueMapTy vmap; + auto cloned_f = CloneFunction(f, vmap); + call->setCalledFunction(cloned_f); + // for simplity and correct handling of references let's call print function + // on the start + Value *print_args = + ConstantPointerNull::get(PointerType::get(m.getContext(), 0)); + if (insert_at_start_fun) { + builder.SetInsertPoint(cloned_f->front().getFirstInsertionPt()); + SmallVector args; + for (auto &a : cloned_f->args()) { + args.emplace_back(&a); + } + print_args = builder.CreateCall(insert_at_start_fun, ArrayRef(args)); + } + for (auto &b : *cloned_f) { + for (auto &i : b) { + if (auto call = dyn_cast(&i)) { + auto f = call->getCalledFunction(); + if (!f) { + continue; + } + std::string demangled = demangle(f->getName()); + auto initial = demangled.find(co_initial_suspend); + if (initial != std::string::npos) { + // we have only one initial_suspend; + auto *await_ready = i.getNextNode(); + auto *br = dyn_cast(await_ready->getNextNode()); + // true case is shorter; + auto *ready_bb = br->getSuccessor(0); + auto *resume_bb = ready_bb->getSingleSuccessor(); + auto it = resume_bb->getFirstInsertionPt(); + //iterate until await resume is meet + while(true){ + auto* call = dyn_cast(&(*it)); + if(call){ + break; + } + it++; + } + builder.SetInsertPoint(++it); + insert(print_args); + return; + } + } + } + } + assert(false && "no initial suspend"); + } + + Constant *GetLiteral(const std::string &name) { + auto literal = string_literals.find(name); + if (literal == string_literals.end()) { + Constant *str_const = + ConstantDataArray::getString(m.getContext(), name, true); + auto zero = ConstantInt::get(Type::getInt32Ty(m.getContext()), 0); + std::array ind = {zero, zero}; + GlobalVariable *global = + new GlobalVariable(m, str_const->getType(), true, + GlobalValue::PrivateLinkage, str_const); + auto ptr = + ConstantExpr::getGetElementPtr(global->getValueType(), global, ind); + literal = string_literals.emplace(name, ptr).first; + } + return literal->second; + } Module &m; - FunctionCallee coroYieldF; - std::vector co_filter; + FunctionCallee coro_yield_f; + FunctionCallee spawned_coro_start_f; + FunctionCallee create_thread_f; + FunctionCallee wait_thread_f; + std::vector co_filter; + std::map string_literals; + std::vector to_delete; + std::set ignored; }; namespace { @@ -230,23 +730,25 @@ namespace { struct CoYieldInsertPass final : public PassInfoMixin { PreservedAnalyses run(Module &m, ModuleAnalysisManager &am) { // NOLINT if (input_list.empty()) { - report_fatal_error("No file with coroutines list"); + report_fatal_error("No file with coroutines list", false); } auto file = llvm::MemoryBuffer::getFile(input_list); if (!file) { - report_fatal_error("Failed to load config file\n"); + errs() << "Tried to read file " << input_list << "\n"; + report_fatal_error("Failed to load config file \n", false); } llvm::yaml::Input input(file.get()->getBuffer()); - std::vector filt; + std::vector filt; input >> filt; if (input.error()) { - report_fatal_error("Error parsing YAML\n"); + errs() << "Tried to parse file " << input_list << "\n"; + report_fatal_error("Error parsing YAML\n", false); } CoYieldInserter gen{m, std::move(filt)}; - gen.Run(m); + gen.Run(); return PreservedAnalyses::none(); }; }; @@ -259,9 +761,21 @@ llvmGetPassPluginInfo() { .PluginName = "coyield_insert", .PluginVersion = "v0.1", .RegisterPassBuilderCallbacks = [](PassBuilder &pb) { + // This parsing we need for testing with opt + pb.registerPipelineParsingCallback( + [](StringRef Name, ModulePassManager &mpm, + ArrayRef) { + if (Name == "coyield_insert") { + mpm.addPass(CoYieldInsertPass()); + return true; + } + return false; + }); pb.registerPipelineStartEPCallback( [](ModulePassManager &mpm, OptimizationLevel level) { - std::set l; + // Looks like we don't need any lowerings, before but i'm + // not sure + // mpm.addPass(CoroEarlyPass()); mpm.addPass(CoYieldInsertPass()); }); }}; diff --git a/runtime/include/lib.h b/runtime/include/lib.h index b629c970..b7197518 100644 --- a/runtime/include/lib.h +++ b/runtime/include/lib.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -18,6 +19,8 @@ struct CoroBase; struct CoroutineStatus; +struct CreatedThreadInfo; +struct WaitThreadInfo; // Current executing coroutine. extern std::shared_ptr this_coro; @@ -26,11 +29,31 @@ extern boost::context::fiber_context sched_ctx; extern std::optional coroutine_status; -struct CoroutineStatus{ +extern std::optional virtual_thread_creation; + +extern std::optional virtual_thread_wait; + +struct CoroutineStatus { std::string_view name; bool has_started; }; +struct CreatedThreadInfo { + std::function function; + int id; + // this will be set after start of scheduler + // name is a literal stored in .data, but args will be set in runtime + std::string_view name = ""; + std::string args = ""; + // will be set in scheduler + size_t parent = -1; + bool has_started = false; +}; + +struct WaitThreadInfo { + std::vector wait_ids; +}; + // Runtime token. // Target method could use token generator. struct Token { @@ -52,6 +75,12 @@ extern "C" void CoroYield(); extern "C" void CoroutineStatusChange(char* coroutine, bool start); +extern "C" void CreateNewVirtualThread(int id, void* func); + +extern "C" void VirtualThreadStartPoint(char* name, char* args); + +extern "C" void WaitForThread(int* ids, int size); + struct CoroBase : public std::enable_shared_from_this { CoroBase(const CoroBase&) = delete; CoroBase(CoroBase&&) = delete; @@ -79,6 +108,10 @@ struct CoroBase : public std::enable_shared_from_this { // Returns the args as strings. virtual std::vector GetStrArgs() const = 0; + virtual void SetStrArgsAndName( + std::string_view name, + std::function(std::shared_ptr)> args) = 0; + // Returns raw pointer to the tuple arguments. virtual void* GetArgs() const = 0; @@ -87,7 +120,7 @@ struct CoroBase : public std::enable_shared_from_this { std::shared_ptr GetPtr(); // Terminate the coroutine. - void Terminate(); + bool Terminate(); // Sets the token. void SetToken(std::shared_ptr); @@ -193,6 +226,14 @@ struct Coro final : public CoroBase { return args_to_strings(args); } + void SetStrArgsAndName( + std::string_view name, + std::function(std::shared_ptr)> args_fun) + override { + this->name = name; + this->args_to_strings = args_fun; + } + void* GetArgs() const override { return args.get(); } private: diff --git a/runtime/include/lincheck.h b/runtime/include/lincheck.h index 178fd3d5..9ea4f89a 100644 --- a/runtime/include/lincheck.h +++ b/runtime/include/lincheck.h @@ -17,7 +17,7 @@ struct Response { ValueWrapper result; int thread_id; - private: +// private: std::reference_wrapper task; }; @@ -28,7 +28,7 @@ struct Invoke { int thread_id; - private: +// private: std::reference_wrapper task; }; diff --git a/runtime/include/pretty_print.h b/runtime/include/pretty_print.h index 145dc974..51008e81 100644 --- a/runtime/include/pretty_print.h +++ b/runtime/include/pretty_print.h @@ -1,18 +1,54 @@ #pragma once +#include #include #include #include +#include #include #include #include "lib.h" #include "lincheck.h" #include "logger.h" - +#include "stable_vector.h" using std::string; using std::to_string; -using FullHistoryWithThreads = std::vector, CoroutineStatus>>>; +struct CreateNewThreadHistoryInfo { + size_t created_thread_id; + std::string_view name; +}; + +using FullHistoryWithThreads = std::vector< + std::pair, CoroutineStatus, + CreateNewThreadHistoryInfo, WaitThreadInfo>>>; + +template +void Dfs(const StableVector& arr, std::vector& visited, size_t i, + std::vector& ans) { + ans.push_back(i); + visited[i] = true; + for (auto& u : arr[i].children) { + if (!visited[u]) { + Dfs(arr, visited, u, ans); + } + } +} + +template +void TopSort(const StableVector& arr, std::vector& ans) { + std::vector visited(arr.size(), false); + for (int i = 0; i < arr.size(); i++) { + if (!visited[i]) { + Dfs(arr, visited, i, ans); + } + } +} + +template +struct Overloads : Ts... { + using Ts::operator()...; +}; + struct PrettyPrinter { PrettyPrinter(size_t threads_num); @@ -32,14 +68,6 @@ struct PrettyPrinter { template void PrettyPrint(const std::vector>& result, Out_t& out) { - auto get_thread_num = [](const std::variant& v) { - // Crutch. - if (v.index() == 0) { - return get<0>(v).thread_id; - } - return get<1>(v).thread_id; - }; - int cell_width = 20; // Up it if necessary. Enough for now. auto print_separator = [&out, this, cell_width]() { @@ -80,7 +108,7 @@ struct PrettyPrinter { // Rows. for (const auto& i : result) { - int num = get_thread_num(i); + int num = std::visit([](auto& a) { return a.thread_id; }, i); out << "|"; for (int j = 0; j < num; ++j) { print_empty_cell(); @@ -88,24 +116,25 @@ struct PrettyPrinter { FitPrinter fp{out, cell_width}; fp.Out(" "); - if (i.index() == 0) { - auto inv = get<0>(i); - auto& task = inv.GetTask(); - fp.Out("[" + std::to_string(task->GetId()) + "] "); - fp.Out(std::string{task->GetName()}); - fp.Out("("); - const auto& args = task->GetStrArgs(); - for (int i = 0; i < args.size(); ++i) { - if (i > 0) { - fp.Out(", "); - } - fp.Out(args[i]); - } - fp.Out(")"); - } else { - auto resp = get<1>(i); - fp.Out("<-- " + to_string(resp.GetTask()->GetRetVal())); - } + std::visit( + Overloads{[&fp](const Invoke& inv) { + auto& task = inv.GetTask(); + fp.Out("[" + std::to_string(task->GetId()) + "] "); + fp.Out(std::string{task->GetName()}); + fp.Out("("); + const auto& args = task->GetStrArgs(); + for (int i = 0; i < args.size(); ++i) { + if (i > 0) { + fp.Out(", "); + } + fp.Out(args[i]); + } + fp.Out(")"); + }, + [&fp](const Response& resp) { + fp.Out("<-- " + to_string(resp.GetTask()->GetRetVal())); + }}, + i); assert(fp.rest > 0 && "increase cell_width in pretty printer"); print_spaces(fp.rest); out << "|"; @@ -121,9 +150,14 @@ struct PrettyPrinter { // Helps to debug full histories. template - void PrettyPrint(FullHistoryWithThreads& result, Out_t& out) { + void PrettyPrint(FullHistoryWithThreads& result, + const std::vector mapping, Out_t& out) { int cell_width = 20; // Up it if necessary. Enough for now. + std::vector inverse_mapping(mapping.size(), -1); + for (int i = 0; i < mapping.size(); i++) { + inverse_mapping[mapping[i]] = i; + } auto print_separator = [&out, this, cell_width]() { out << "*"; for (int i = 0; i < threads_num; ++i) { @@ -150,7 +184,7 @@ struct PrettyPrinter { for (int i = 0; i < threads_num; ++i) { int rest = cell_width - 1 /*T*/ - to_string(i).size(); print_spaces(rest / 2); - out << "T" << i; + out << "T" << mapping[i]; print_spaces(rest - rest / 2); out << "|"; } @@ -168,59 +202,79 @@ struct PrettyPrinter { std::vector co_depth(threads_num, 0); // Rows. for (const auto& i : result) { - int num = i.first; + int num = inverse_mapping[i.first]; FitPrinter fp{out, cell_width}; - if (i.second.index() == 0) { - auto act = std::get<0>(i.second); - auto base = act.get().get(); - if (index.find(base) == index.end()) { - int sz = index.size(); - index[base] = sz; - } - int length = std::to_string(index[base]).size(); - std::cout << index[base]; - assert(spaces - length >= 0); - print_spaces(7 - length); - out << "|"; - for (int j = 0; j < num; ++j) { - print_empty_cell(); - } - fp.Out(" "); - fp.Out(std::string{act.get()->GetName()}); - fp.Out("("); - const auto& args = act.get()->GetStrArgs(); - for (int i = 0; i < args.size(); ++i) { - if (i > 0) { - fp.Out(", "); - } - fp.Out(args[i]); - } - fp.Out(")"); - } else if (i.second.index() == 1) { - print_spaces(7); - out << "|"; - for (int j = 0; j < num; ++j) { - print_empty_cell(); - } - auto cor = std::get<1>(i.second); - auto print_formated_spaces = [&fp](int count) { - for (int i = 0; i < count; ++i) { - fp.Out(" "); - } - }; - if (cor.has_started) { - print_formated_spaces(co_depth[num] + 1); - fp.Out(">"); - co_depth[num]++; - } else { - print_formated_spaces(co_depth[num]); - fp.Out("<"); - co_depth[num]--; - } - fp.Out(cor.name); - // std::cerr << cor.name << "\n"; - assert(fp.rest > 0 && "increase cell_width in pretty printer"); - } + auto visitor = + Overloads{[&](std::reference_wrapper act) { + auto base = act.get().get(); + if (index.find(base) == index.end()) { + int sz = index.size(); + index[base] = sz; + } + int length = std::to_string(index[base]).size(); + out << index[base]; + assert(spaces - length >= 0); + print_spaces(7 - length); + out << "|"; + for (int j = 0; j < num; ++j) { + print_empty_cell(); + } + fp.Out(" "); + // std::cerr << "writing " << &act.get() << "\n"; + fp.Out(std::string{act.get()->GetName()}); + fp.Out("("); + const auto& args = act.get()->GetStrArgs(); + for (int i = 0; i < args.size(); ++i) { + if (i > 0) { + fp.Out(", "); + } + fp.Out(args[i]); + } + fp.Out(")"); + }, + [&](const CoroutineStatus& cor) { + print_spaces(7); + out << "|"; + for (int j = 0; j < num; ++j) { + print_empty_cell(); + } + auto print_formated_spaces = [&fp](int count) { + for (int i = 0; i < count; ++i) { + fp.Out(" "); + } + }; + if (cor.has_started) { + print_formated_spaces(co_depth[num] + 1); + fp.Out(">"); + co_depth[num]++; + } else { + print_formated_spaces(co_depth[num]); + fp.Out("<"); + co_depth[num]--; + } + fp.Out(cor.name); + }, + [&](const CreateNewThreadHistoryInfo& new_thread) { + print_spaces(7); + out << "|"; + for (int j = 0; j < num; ++j) { + print_empty_cell(); + } + fp.Out(std::string(" ->T") + + std::to_string(new_thread.created_thread_id)); + }, + [&](const WaitThreadInfo& wait_thread) { + print_spaces(7); + out << "|"; + for (int j = 0; j < num; ++j) { + print_empty_cell(); + } + fp.Out(std::string(" <-")); + } + + }; + std::visit(visitor, i.second); + assert(fp.rest > 0 && "increase cell_width in pretty printer"); print_spaces(fp.rest); out << "|"; diff --git a/runtime/include/scheduler.h b/runtime/include/scheduler.h index a6c22e2e..0ae223fe 100644 --- a/runtime/include/scheduler.h +++ b/runtime/include/scheduler.h @@ -1,13 +1,18 @@ #pragma once #include #include +#include #include #include #include +#include #include #include #include +#include #include +#include +#include #include "lib.h" #include "lincheck.h" @@ -17,6 +22,7 @@ #include "pretty_print.h" #include "scheduler_fwd.h" #include "stable_vector.h" +#include "value_wrapper.h" /// Generated by some strategy task, /// that may be not executed due to constraints of data structure @@ -485,23 +491,24 @@ struct StrategyScheduler : public SchedulerWithReplay { // TLAScheduler generates all executions satisfying some conditions. template struct TLAScheduler : Scheduler { - TLAScheduler(size_t max_tasks, size_t max_rounds, size_t threads_count, - size_t max_switches, size_t max_depth, - std::vector constructors, ModelChecker& checker, - PrettyPrinter& pretty_printer, std::function cancel_func) + TLAScheduler(size_t max_tasks, size_t max_rounds, + size_t initial_threads_count, size_t max_switches, + size_t max_depth, std::vector constructors, + ModelChecker& checker, std::function cancel_func) : max_tasks{max_tasks}, max_rounds{max_rounds}, max_switches{max_switches}, + initial_threads_count(initial_threads_count), constructors{std::move(constructors)}, checker{checker}, - pretty_printer{pretty_printer}, max_depth(max_depth), cancel(cancel_func) { - for (size_t i = 0; i < threads_count; ++i) { - threads.emplace_back(Thread{ - .id = i, - .tasks = StableVector{}, - }); + for (size_t i = 0; i < initial_threads_count; ++i) { + threads.emplace_back(Thread{.id = i, + .tasks = StableVector{}, + .children = {}, + .created_meta = {}, + .wait_cond = {}}); } }; @@ -516,8 +523,23 @@ struct TLAScheduler : Scheduler { struct Thread { size_t id; StableVector tasks; + std::vector children; + std::optional created_meta; + std::optional wait_cond; + }; + struct WaitId { + // this is the id from config + int base_id; + // this is the id of thread from where the task called + size_t thread; + // TODO add this ptr + bool operator==(const WaitId& oth) const = default; + }; + struct WaitHasher { + std::size_t operator()(const WaitId& id) const { + return (id.thread & 0xff) | ((id.base_id & 0xff) << 8); + } }; - // TLAScheduler enumerates all possible executions with finished max_tasks. // In fact, it enumerates tables (c = continue, f = finished): // *---------*---------*--------* @@ -535,22 +557,41 @@ struct TLAScheduler : Scheduler { Task* task{}; // Is true if the task was created at this step. bool is_new{}; + size_t thread_id; }; + // In structured concurrency so kind of termination will be safe for children + // threads + void TerminateThread(size_t i) { + Thread& thr = threads[i]; + for (size_t i = 0; i < thr.children.size(); i++) { + TerminateThread(thr.children[i]); + } + thr.children.clear(); + thr.wait_cond.reset(); + for (size_t j = 0; j < thr.tasks.size(); ++j) { + auto& task = thr.tasks[j]; + if (!task->IsReturned()) { + if (!task->Terminate()) { + // a new thread was spawned, so we need to create and finish him first + CreateNewThread(i); + virtual_thread_creation.reset(); + TerminateThread(threads.size() - 1); + } + } + } + } // Terminates all running tasks. // We do it in a dangerous way: in random order. // Actually, we assume obstruction free here. // cancel() func takes care for graceful shutdown void TerminateTasks() { cancel(); - for (size_t i = 0; i < threads.size(); ++i) { - for (size_t j = 0; j < threads[i].tasks.size(); ++j) { - auto& task = threads[i].tasks[j]; - if (!task->IsReturned()) { - task->Terminate(); - } - } + for (size_t i = 0; i < initial_threads_count; ++i) { + TerminateThread(i); } + started_thread_groups.clear(); + threads.resize(initial_threads_count); } // Replays all actions from 0 to the step_end. @@ -560,43 +601,128 @@ struct TLAScheduler : Scheduler { // In histories we store references, so there's no need to update it. state.Reset(); for (size_t step = 0; step < step_end; ++step) { - auto& frame = frames[step]; - auto task = frame.task; + Frame& frame = frames[step]; + Task* task = frame.task; assert(task); if (frame.is_new) { - // It was a new task. - // So restart it from the beginning with the same args. - *task = (*task)->Restart(&state); + // It was a new task.frame.created_meta + // So restart it from the beginning with the same + // args. + if (frame.thread_id < initial_threads_count) { + *task = (*task)->Restart(&state); + } else { + assert(false); + } } else { - // It was a not new task, hence, we recreated in early. + // It was a not new task, hence, we recreated in + // early. + if (frame.thread_id >= initial_threads_count) { + task = &threads[frame.thread_id].tasks.back(); + } } (*task)->Resume(); + if (virtual_thread_creation) { + CreateNewThread(frame.thread_id); + virtual_thread_creation.reset(); + } + AddWaitCond(frame.thread_id); + bool is_finished = (*task)->IsReturned(); + if (is_finished && frame.thread_id >= initial_threads_count) { + started_thread_groups[{ + .base_id = threads[frame.thread_id].created_meta->id, + .thread = threads[frame.thread_id].created_meta->parent}]--; + } + coroutine_status.reset(); + } + // Horrible but how to this better? + for (auto& v : full_history) { + if (v.first >= initial_threads_count) { + Task* new_task = &threads[v.first].tasks.back(); + std::visit(Overloads{[&new_task, &v](std::reference_wrapper& t) { + t = std::reference_wrapper(*new_task); + }, + [](auto& a) {}}, + v.second); + } + } + for (auto& v : sequential_history) { + size_t thread_id = std::visit([](auto& a) { return a.thread_id; }, v); + if (thread_id >= initial_threads_count) { + std::visit([this, thread_id]( + auto& a) { a.task = threads[thread_id].tasks.back(); }, + v); + } + } + } + + void CreateNewThread(size_t parent_id) { + assert(virtual_thread_creation.has_value()); + threads.emplace_back(Thread{.id = threads.size(), + .tasks = {}, + .children = {}, + .created_meta = virtual_thread_creation, + .wait_cond = {}}); + threads.back().created_meta->parent = parent_id; + threads[parent_id].children.push_back(threads.size() - 1); + started_thread_groups + .emplace( + WaitId{.base_id = virtual_thread_creation->id, .thread = parent_id}, + 0) + .first->second++; + auto& tasks = threads.back().tasks; + tasks.emplace_back(CreateSpawnedTask(*virtual_thread_creation)); + tasks.back()->Resume(); + tasks.back()->SetStrArgsAndName(virtual_thread_creation->name, + [](auto _) { return std::vector{}; }); + } + + void AddWaitCond(size_t thread_id) { + if (virtual_thread_wait) { + assert(thread_id < initial_threads_count); + assert(!threads[thread_id].wait_cond); + threads[thread_id].wait_cond = virtual_thread_wait; + virtual_thread_wait.reset(); + return; } - coroutine_status.reset(); } void UpdateFullHistory(size_t thread_id, Task& task, bool is_new) { + if (virtual_thread_creation) { + // this is will push us to the moment where coro is real started + CreateNewThread(thread_id); + full_history.emplace_back(thread_id, task); + // TODO: looks like we need a seperate function fo verifier + // verifier.UpdateState(coroutine_status->name, thread_id, + // true); + full_history.emplace_back( + thread_id, CreateNewThreadHistoryInfo{threads.size() - 1, + virtual_thread_creation->name}); + virtual_thread_creation.reset(); + return; + } + AddWaitCond(thread_id); if (coroutine_status.has_value()) { if (is_new) { assert(coroutine_status->has_started); full_history.emplace_back(thread_id, task); } - //To prevent cases like this - // +--------+--------+ - // | T1 | T2 | - // +--------+--------+ - // | | Recv | - // | Send | | - // | | >read | - // | >flush | | - // +--------+--------+ - verifier.UpdateState(coroutine_status->name, thread_id, coroutine_status->has_started); + // To prevent cases like this + // +--------+--------+ + // | T1 | T2 | + // +--------+--------+ + // | | Recv | + // | Send | | + // | | >read | + // | >flush | | + // +--------+--------+ + verifier.UpdateState(coroutine_status->name, thread_id, + coroutine_status->has_started); full_history.emplace_back(thread_id, coroutine_status.value()); coroutine_status.reset(); - } else { - verifier.UpdateState(task->GetName(), thread_id, is_new); - full_history.emplace_back(thread_id, task); + return; } + verifier.UpdateState(task->GetName(), thread_id, is_new); + full_history.emplace_back(thread_id, task); } // Resumes choosed task. // If task is finished and finished tasks == max_tasks, stops. @@ -630,8 +756,14 @@ struct TLAScheduler : Scheduler { UpdateFullHistory(thread_id, task, is_new); bool is_finished = task->IsReturned(); if (is_finished) { - finished_tasks++; verifier.OnFinished(TaskWithMetaData{task, false, thread.id}); + + if (thread.created_meta) { + started_thread_groups[{.base_id = thread.created_meta->id, + .thread = thread.created_meta->parent}]--; + } else { + finished_tasks++; + } auto result = task->GetRetVal(); sequential_history.emplace_back(Response(task, result, thread_id)); } @@ -645,7 +777,10 @@ struct TLAScheduler : Scheduler { } } else { log() << "run round: " << finished_rounds << "\n"; - pretty_printer.PrettyPrint(full_history, log()); + std::vector mapping; + TopSort(threads, mapping); + PrettyPrinter pretty_printer{threads.size()}; + pretty_printer.PrettyPrint(full_history, mapping, log()); log() << "===============================================\n\n"; log().flush(); // Stop, check if the the generated history is linearizable. @@ -655,65 +790,103 @@ struct TLAScheduler : Scheduler { std::make_pair(Scheduler::FullHistory{}, sequential_history)}; } if (finished_rounds == max_rounds) { - // It was the last round. + // It was the last round.q return {true, {}}; } } thread_id_history.pop_back(); - // Removing combination of start of task + coroutine start - if (full_history.back().second.index() == 1) { - auto& cor = std::get<1>(full_history.back().second); - auto& prev = full_history[full_history.size() - 2]; - int thread = full_history.back().first; - auto first_ind = - std::find_if(full_history.begin(), --full_history.end(), - [&thread](auto& a) { return a.first == thread; }); - if (cor.has_started && - std::distance(full_history.begin(), first_ind) == - full_history.size() - 2 && - prev.second.index() == 0) { - full_history.pop_back(); - } - } + auto visitor = Overloads{ + // Removing combination of start of task + coroutine start + [this](const CoroutineStatus& status) { + auto& cor = std::get<1>(full_history.back().second); + auto& prev = full_history[full_history.size() - 2]; + int thread = full_history.back().first; + auto first_ind = + std::find_if(full_history.begin(), --full_history.end(), + [&thread](auto& a) { return a.first == thread; }); + if (cor.has_started && + std::distance(full_history.begin(), first_ind) == + full_history.size() - 2 && + prev.second.index() == 0) { + full_history.pop_back(); + } + }, + [this](const CreateNewThreadHistoryInfo& created) { + full_history.pop_back(); + }, + [this](const WaitThreadInfo& wait) { full_history.pop_back(); }, + [](const std::reference_wrapper& a) {}}; + std::visit(visitor, full_history.back().second); full_history.pop_back(); if (is_finished) { - --finished_tasks; + if (thread_id < initial_threads_count) { + --finished_tasks; + } // resp. sequential_history.pop_back(); } if (is_new) { // inv. - --started_tasks; + if (thread_id < initial_threads_count) { + --started_tasks; + } sequential_history.pop_back(); } - return {false, {}}; } + Task CreateSpawnedTask(const CreatedThreadInfo& created) { + return Coro::New( + [&created](void*) -> ValueWrapper { + created.function(); + return void_v; + }, + &state, std::shared_ptr(), + [](std::shared_ptr) -> std::vector { return {}; }, + {}, -1); + } std::tuple RunStep(size_t step, size_t switches) { // Push frame to the stack. frames.emplace_back(Frame{}); auto& frame = frames.back(); - bool all_parked = true; // Pick next task. for (size_t i = 0; i < threads.size(); ++i) { - auto& thread = threads[i]; + Thread& thread = threads[i]; auto& tasks = thread.tasks; + if (thread.wait_cond) { + if (!std::all_of( + thread.wait_cond->wait_ids.begin(), + thread.wait_cond->wait_ids.end(), [this, &i](auto& child) { + auto it = started_thread_groups.find( + {.base_id = child, .thread = i}); + return it == started_thread_groups.end() || it->second == 0; + })) { + continue; + } + // full_history.emplace_back( + // WaitThreadInfo{.wait_ids = thread.wait_cond.value().wait_ids}); + thread.wait_cond.reset(); + } if (!tasks.empty() && !tasks.back()->IsReturned()) { if (tasks.back()->IsParked()) { continue; } all_parked = false; + auto& meta = thread.created_meta; + bool is_new = + meta ? !std::exchange((*thread.created_meta).has_started, true) + : false; if (!verifier.Verify(CreatedTaskMetaData{ - std::string{tasks.back()->GetName()}, false, i})) { + std::string{tasks.back()->GetName()}, is_new, i})) { continue; } // Task exists. frame.is_new = false; - auto [is_over, res] = ResumeTask(frame, step, switches, thread, false); + frame.thread_id = i; + auto [is_over, res] = ResumeTask(frame, step, switches, thread, is_new); if (is_over || res.has_value()) { return {is_over, res}; } @@ -724,6 +897,9 @@ struct TLAScheduler : Scheduler { } all_parked = false; + if (i >= initial_threads_count) { + continue; + } // Choose constructor to create task. bool stop = started_tasks == max_tasks; if (!stop && threads[i].tasks.size() < max_depth) { @@ -731,10 +907,11 @@ struct TLAScheduler : Scheduler { if (!verifier.Verify(CreatedTaskMetaData{cons.GetName(), true, i})) { continue; } - frame.is_new = true; - auto size_before = tasks.size(); tasks.emplace_back(cons.Build(&state, i, -1/* TODO: fix task id for tla, because it is Scheduler and not Strategy class for some reason */)); started_tasks++; + frame.is_new = true; + frame.thread_id = i; + auto size_before = tasks.size() - 1; auto [is_over, res] = ResumeTask(frame, step, switches, thread, true); if (is_over || res.has_value()) { return {is_over, res}; @@ -742,8 +919,8 @@ struct TLAScheduler : Scheduler { tasks.pop_back(); auto size_after = thread.tasks.size(); assert(size_before == size_after); - // As we can't return to the past in coroutine, we need to replay all - // tasks from the beginning. + // As we can't return to the past in coroutine, we need to replay + // all tasks from the beginning. Replay(step); } } @@ -754,9 +931,9 @@ struct TLAScheduler : Scheduler { return {false, {}}; } - PrettyPrinter& pretty_printer; size_t max_tasks; size_t max_rounds; + size_t initial_threads_count; size_t max_switches; size_t max_depth; @@ -775,4 +952,5 @@ struct TLAScheduler : Scheduler { StableVector frames; Verifier verifier; std::function cancel; + std::unordered_map started_thread_groups; }; diff --git a/runtime/include/stable_vector.h b/runtime/include/stable_vector.h index 6067fd8f..d21fb0f8 100644 --- a/runtime/include/stable_vector.h +++ b/runtime/include/stable_vector.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -49,7 +50,7 @@ struct StableVector { template requires std::constructible_from T &emplace_back(Args &&...args) { - const size_t index = 63 - __builtin_clzll(total_size + 1); + const size_t index = 63 - std::countl_zero(total_size + 1); if (((total_size + 1) & total_size) == 0 && !entities[index]) { entities[index] = ::new (static_cast(alignof(T))) type_t[1ULL << index]; @@ -62,7 +63,7 @@ struct StableVector { } void pop_back() noexcept { - const size_t index = 63 - __builtin_clzll(total_size); + const size_t index = 63 - std::countl_zero(total_size); if (((total_size - 1) & total_size) == 0 && entities[index + 1]) { ::operator delete[](std::exchange(entities[index + 1], nullptr), static_cast(alignof(T))); @@ -73,14 +74,14 @@ struct StableVector { } T &operator[](size_t i) noexcept { - const size_t index = 63 - __builtin_clzll(++i); + const size_t index = 63 - std::countl_zero(++i); const size_t internal_index = i ^ (1ULL << index); return *std::launder( reinterpret_cast(entities[index][internal_index])); } const T &operator[](size_t i) const noexcept { - const size_t index = 63 - __builtin_clzll(++i); + const size_t index = 63 - std::countl_zero(++i); const size_t internal_index = i ^ (1ULL << index); return *std::launder( reinterpret_cast(entities[index][internal_index])); diff --git a/runtime/include/verifying.h b/runtime/include/verifying.h index c4391dc4..737dc8c1 100644 --- a/runtime/include/verifying.h +++ b/runtime/include/verifying.h @@ -3,6 +3,7 @@ #include #include +#include #include "lib.h" #include "lincheck_recursive.h" @@ -143,7 +144,7 @@ std::unique_ptr MakeScheduler(ModelChecker &checker, Opts &opts, std::cout << "tla\n"; auto scheduler = std::make_unique>( opts.tasks, opts.rounds, opts.threads, opts.switches, opts.depth, - std::move(l), checker, pretty_printer, cancel); + std::move(l), checker, cancel); return scheduler; } default: { @@ -152,13 +153,19 @@ std::unique_ptr MakeScheduler(ModelChecker &checker, Opts &opts, } } -inline int TrapRun(std::unique_ptr &&scheduler, - PrettyPrinter &pretty_printer) { +inline int TrapRun(std::unique_ptr &&scheduler) { auto guard = SyscallTrapGuard{}; auto result = scheduler->Run(); if (result.has_value()) { std::cout << "non linearized:\n"; - pretty_printer.PrettyPrint(result.value().second, std::cout); + int count = 0; + for (auto& v : result->second) { + std::visit([&count](auto& a){ + count = std::max(count, a.thread_id); + }, v); + } + PrettyPrinter pretty_printer(count + 1); + pretty_printer.PrettyPrint(result->second, std::cout); return 1; } else { std::cout << "success!\n"; @@ -202,7 +209,7 @@ int Run(int argc, char *argv[]) { &Spec::cancel_t::Cancel); std::cout << "\n\n"; std::cout.flush(); - return TrapRun(std::move(scheduler), pretty_printer); + return TrapRun(std::move(scheduler)); } } // namespace ltest diff --git a/runtime/lib.cpp b/runtime/lib.cpp index 2292d76f..cf4d5c0e 100644 --- a/runtime/lib.cpp +++ b/runtime/lib.cpp @@ -1,6 +1,10 @@ #include "include/lib.h" +#include #include +#include +#include +#include #include #include #include @@ -13,6 +17,10 @@ Task this_coro{}; boost::context::fiber_context sched_ctx; std::optional coroutine_status; +std::optional virtual_thread_creation; + +std::optional virtual_thread_wait; + std::unordered_map futex_state{}; namespace ltest { @@ -63,19 +71,46 @@ extern "C" void CoroYield() { } extern "C" void CoroutineStatusChange(char* name, bool start) { - // assert(!coroutine_status.has_value()); + assert(!coroutine_status.has_value()); coroutine_status.emplace(name, start); CoroYield(); } -void CoroBase::Terminate() { +extern "C" void CreateNewVirtualThread(int id, void* func) { + virtual_thread_creation.emplace(reinterpret_cast(func), id); + CoroYield(); +} + +extern "C" void VirtualThreadStartPoint(char* name, char* args) { + assert(virtual_thread_creation); + virtual_thread_creation->name = name; + virtual_thread_creation->args = std::string(args); + CoroYield(); +} + +extern "C" void WaitForThread(int* ids, int size) { + std::vector vids(size); + std::copy(ids, ids + size, vids.begin()); + virtual_thread_wait.emplace(vids); + CoroYield(); +} + +bool CoroBase::Terminate() { int tries = 0; while (!IsReturned()) { ++tries; Resume(); + // we don't care about this while terminating + coroutine_status.reset(); + virtual_thread_wait.reset(); + // we couldn't process before the spawned thread is spawned and finished + if (virtual_thread_creation) { + return false; + } assert(tries < 10000000 && "coroutine is spinning too long, possible wrong terminating order"); } + return true; } void Token::Reset() { parked = false; } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4c5f637e..7d2b3efd 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(runtime) +# add_subdirectory(codegen) \ No newline at end of file diff --git a/test/codegen/CMakeLists.txt b/test/codegen/CMakeLists.txt new file mode 100644 index 00000000..62de3fc1 --- /dev/null +++ b/test/codegen/CMakeLists.txt @@ -0,0 +1,15 @@ + +find_package(Python3 REQUIRED) +find_program(LIT_EXECUTABLE NAMES lit lit.py) +if(NOT LIT_EXECUTABLE) + message(FATAL_ERROR "Could not find lit testing tool") +endif() + +add_test( + NAME "coyield" + COMMAND ${LIT_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR} --filter-out + ${CMAKE_CURRENT_SOURCE_DIR}/coyield/.* --verbose +) +set_tests_properties("coyield" + PROPERTIES + LABELS "llvm-pass") \ No newline at end of file diff --git a/test/codegen/coyield/lit.local.cfg b/test/codegen/coyield/lit.local.cfg new file mode 100644 index 00000000..d8e7a20d --- /dev/null +++ b/test/codegen/coyield/lit.local.cfg @@ -0,0 +1,10 @@ +import lit +config.name = 'CoYieldPass' +config.build_pass_root = os.path.join(config.lit_config_dir, '..', '..' ,'build', 'codegen') + +#todo add normal handling of config files +config.substitutions.append(('%check', f""" +clang++ %s -emit-llvm -Xclang -disable-llvm-passes -S -std=c++20 -o - +| opt -load-pass-plugin={config.build_pass_root}/libCoYieldPass.so --coroutine-file=%s.yml +-passes=coyield_insert -verify-each -stop-after=verify -S +| FileCheck %s""")) diff --git a/test/codegen/coyield/tests/bool_suspend.cpp b/test/codegen/coyield/tests/bool_suspend.cpp new file mode 100644 index 00000000..6fe29198 --- /dev/null +++ b/test/codegen/coyield/tests/bool_suspend.cpp @@ -0,0 +1,37 @@ +// RUN: %check +#include +struct SimpleAwaitable { + bool await_ready() const noexcept { + return false; + } + + bool await_suspend(std::coroutine_handle<> h) const noexcept { + h.resume(); + return false; + } + + void await_resume() const noexcept { + } +}; + +struct CoroTask { + struct promise_type { + CoroTask get_return_object() { return {}; } + std::suspend_never initial_suspend() { return {}; } + std::suspend_never final_suspend() noexcept { return {}; } + void return_void() {} + void unhandled_exception() {} + }; + auto operator co_await() const { return SimpleAwaitable{}; } + +}; + +CoroTask myCoroutine2() { + co_return; +} + +CoroTask myCoroutine() { + // CHECK: call void @CoroutineStatusChange(ptr [[name:@[0-9]+]], i1 true) + co_await myCoroutine2(); + // CHECK: call void @CoroutineStatusChange(ptr [[name]], i1 false) +} \ No newline at end of file diff --git a/test/codegen/coyield/tests/bool_suspend.cpp.yml b/test/codegen/coyield/tests/bool_suspend.cpp.yml new file mode 100644 index 00000000..aac6c521 --- /dev/null +++ b/test/codegen/coyield/tests/bool_suspend.cpp.yml @@ -0,0 +1,4 @@ +- Action: Coro + Name: test + Place: + Function: myCoroutine2() diff --git a/test/codegen/coyield/tests/simple_insert.cpp b/test/codegen/coyield/tests/simple_insert.cpp new file mode 100644 index 00000000..d8a44dfa --- /dev/null +++ b/test/codegen/coyield/tests/simple_insert.cpp @@ -0,0 +1,19 @@ +// RUN: %check +int g() { return 1; } +void based_on_pos(){ + +} + +void ignored() {} +int f() { + ignored(); + // CHECK: call void @CoroutineStatusChange(ptr [[name:@[0-9]+]], i1 true) + // CHECK-NEXT: %[[res:.*]] = call noundef i32 @_Z1gv() + return g(); + // CHECK: call void @CoroutineStatusChange(ptr [[name]], i1 false) + // CHECK-NEXT: ret i32 %[[res]] +} + +int ignored_f(){ + return g(); +} \ No newline at end of file diff --git a/test/codegen/coyield/tests/simple_insert.cpp.yml b/test/codegen/coyield/tests/simple_insert.cpp.yml new file mode 100644 index 00000000..17e146e0 --- /dev/null +++ b/test/codegen/coyield/tests/simple_insert.cpp.yml @@ -0,0 +1,10 @@ +- Action: Generic + Name: test + Place: + Parent: f() + Function: g() +- Action: Generic + Name: pos_test + Place: + Parent: based_on_pos() + Function: g() diff --git a/test/codegen/coyield/tests/thread_creation.cpp b/test/codegen/coyield/tests/thread_creation.cpp new file mode 100644 index 00000000..96067902 --- /dev/null +++ b/test/codegen/coyield/tests/thread_creation.cpp @@ -0,0 +1,50 @@ +// RUN: %check +#include +#include +static std::string str; + +extern "C" char* PrintInt(int i) { + str = std::to_string(i); + return str.data(); +}; + +struct Promise; +// NOLINTBEGIN(readability-identifier-naming) +struct SimpleAwaitable { + bool await_ready() const noexcept { return false; } + + bool await_suspend(std::coroutine_handle<> h) const noexcept { + h.resume(); + return true; + } + + void await_resume() const noexcept {} +}; +struct Coroutine : std::coroutine_handle { + using promise_type = ::Promise; + auto operator co_await() const { return SimpleAwaitable{}; } +}; + +struct Promise { + Coroutine get_return_object() { return {Coroutine::from_promise(*this)}; } + std::suspend_always initial_suspend() noexcept { return {}; } + std::suspend_always final_suspend() noexcept { return {}; } + void return_void() {} + void unhandled_exception() {} +}; +// NOLINTEND(readability-identifier-naming) + +// Let's omit realization for simplicity +struct Waiter { + void Add(Coroutine coro) {} + SimpleAwaitable Wait() { return {}; } +}; + +Coroutine DoWork(int i) { co_return; } +Coroutine Work(int i) { + Waiter w; + // CHECK: call + w.Add(DoWork(i)); + // CHECK: call + co_await w.Wait(); +} \ No newline at end of file diff --git a/test/codegen/coyield/tests/thread_creation.cpp.yml b/test/codegen/coyield/tests/thread_creation.cpp.yml new file mode 100644 index 00000000..0b08565b --- /dev/null +++ b/test/codegen/coyield/tests/thread_creation.cpp.yml @@ -0,0 +1,14 @@ +- Action: Spawn + CreationId: 0 + HasThis: True + Place: + Function: Waiter::Add(Coroutine) +- Action: Wait + WaitsFor: [0] + Place: + Function: Waiter::Wait() +- Action: SpawnedCoro + ArgsFun: PrintInt + Name: DoWork + Place: + Function: DoWork(int) \ No newline at end of file diff --git a/test/codegen/coyield/tests/void_suspend.cpp b/test/codegen/coyield/tests/void_suspend.cpp new file mode 100644 index 00000000..196d13cb --- /dev/null +++ b/test/codegen/coyield/tests/void_suspend.cpp @@ -0,0 +1,27 @@ +// RUN: %check +#include +struct SimpleAwaitable { + bool await_ready() const noexcept { return false; } + + void await_suspend(std::coroutine_handle<> h) const noexcept { h.resume(); } + + void await_resume() const noexcept {} +}; + +struct CoroTask { + struct promise_type { + CoroTask get_return_object() { return {}; } + std::suspend_never initial_suspend() { return {}; } + std::suspend_never final_suspend() noexcept { return {}; } + void return_void() {} + void unhandled_exception() {} + }; + auto operator co_await() const { return SimpleAwaitable{}; } +}; + +CoroTask myCoroutine2() { co_return; } +CoroTask myCoroutine() { + // CHECK: call void @CoroutineStatusChange(ptr [[name:@[0-9]+]], i1 true) + co_await myCoroutine2(); + // CHECK: call void @CoroutineStatusChange(ptr [[name]], i1 false) +} \ No newline at end of file diff --git a/test/codegen/coyield/tests/void_suspend.cpp.yml b/test/codegen/coyield/tests/void_suspend.cpp.yml new file mode 100644 index 00000000..aac6c521 --- /dev/null +++ b/test/codegen/coyield/tests/void_suspend.cpp.yml @@ -0,0 +1,4 @@ +- Action: Coro + Name: test + Place: + Function: myCoroutine2() diff --git a/test/codegen/lit.cfg b/test/codegen/lit.cfg new file mode 100644 index 00000000..d34efde5 --- /dev/null +++ b/test/codegen/lit.cfg @@ -0,0 +1,9 @@ +import lit + +config.name = 'Codegen' + +config.suffixes = ['.cpp'] + +config.test_format = lit.formats.ShTest() +config.lit_config_dir = os.path.dirname(os.path.abspath(__file__)) +config.test_source_root = None \ No newline at end of file diff --git a/test/runtime/stackfulltask_mock.h b/test/runtime/stackfulltask_mock.h index 32fd28a2..3cf8de86 100644 --- a/test/runtime/stackfulltask_mock.h +++ b/test/runtime/stackfulltask_mock.h @@ -18,5 +18,9 @@ class MockTask : public CoroBase { MOCK_METHOD(bool, IsSuspended, (), (const)); MOCK_METHOD(void, Terminate, (), ()); MOCK_METHOD(void, SetToken, (std::shared_ptr), ()); + MOCK_METHOD(void, SetStrArgsAndName, + (std::string_view, + std::function(std::shared_ptr)>), + (override)); virtual ~MockTask() { is_returned = true; } }; diff --git a/verifying/specs/communication.h b/verifying/specs/communication.h new file mode 100644 index 00000000..82a99e30 --- /dev/null +++ b/verifying/specs/communication.h @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include + +#include "../../runtime/include/verifying.h" +#include "support_coro.h" + +constexpr int writer_count = 2; +constexpr int start_id = 100; +namespace spec { + +struct CommunicationRef { + std::deque buf; + int id = start_id; + CommunicationRef() {} + CommunicationRef &operator=(const CommunicationRef &oth) { return *this; } + + void Send(int i) { + buf.push_back(id); + id++; + buf.push_back(i); + } + + void Receive() { + int m_id = buf.front(); + buf.pop_front(); + int message = buf.front(); + buf.pop_front(); + } + using MethodT = std::function; + static auto GetMethods() { + MethodT receive = [](CommunicationRef *l, void *args) { + l->Receive(); + return void_v; + }; + MethodT send = [](CommunicationRef *l, void *args) { + auto real_args = reinterpret_cast *>(args); + l->Send(std::get<0>(*real_args)); + return void_v; + }; + return std::map{{"Send", send}, {"Receive", receive}}; + } +}; + +struct UniqueArgsHash { + size_t operator()(const CommunicationRef &r) const { + int res = 0; + for (int elem : r.buf) { + res += elem; + } + return res; + } +}; +struct UniqueArgsEquals { + bool operator()(const CommunicationRef &lhs, + const CommunicationRef &rhs) const { + return lhs.buf == rhs.buf; + } +}; +struct UniqueArgsOptionsOverride { + static ltest::DefaultOptions GetOptions() { + return {.threads = writer_count + 1, + .tasks = writer_count + 1, + .switches = 100000000, + .rounds = 10000, + .depth = 1, + .forbid_all_same = false, + .verbose = false, + .strategy = "tla", + .weights = ""}; + } +}; + +} // namespace spec diff --git a/verifying/specs/support_coro.h b/verifying/specs/support_coro.h new file mode 100644 index 00000000..35b02873 --- /dev/null +++ b/verifying/specs/support_coro.h @@ -0,0 +1,33 @@ +#include +struct Promise; + + +// NOLINTBEGIN(readability-identifier-naming) +struct SimpleAwaitable { + bool await_ready() const noexcept { + return false; + } + + bool await_suspend(std::coroutine_handle<> h) const noexcept { + h.resume(); + return true; + } + + void await_resume() const noexcept { + } + +}; +struct Coroutine : std::coroutine_handle { + using promise_type = ::Promise; + auto operator co_await() const { return SimpleAwaitable{}; } +}; + + +struct Promise { + Coroutine get_return_object() { return {Coroutine::from_promise(*this)}; } + std::suspend_never initial_suspend() noexcept { return {}; } + std::suspend_always final_suspend() noexcept { return {}; } + void return_void() {} + void unhandled_exception() {} +}; +// NOLINTEND(readability-identifier-naming) \ No newline at end of file diff --git a/verifying/specs/unique_args.h b/verifying/specs/unique_args.h index afa13996..35a0c329 100644 --- a/verifying/specs/unique_args.h +++ b/verifying/specs/unique_args.h @@ -23,17 +23,18 @@ struct UniqueArgsRef { return {called == limit ? std::exchange(called, 0) : std::optional(), GetDefaultCompator>(), Print}; } - + void DoWork() { return; } using MethodT = std::function; static auto GetMethods() { MethodT get = [](UniqueArgsRef *l, void *args) { auto real_args = reinterpret_cast *>(args); return l->Get(std::get<0>(*real_args)); }; - - return std::map{ - {"Get", get}, + MethodT do_work = [](UniqueArgsRef *l, void *args) { + l->DoWork(); + return void_v; }; + return std::map{{"Get", get}, {"DoWork", do_work}}; } }; diff --git a/verifying/targets/CMakeLists.txt b/verifying/targets/CMakeLists.txt index 6e8d2023..8a3e8028 100644 --- a/verifying/targets/CMakeLists.txt +++ b/verifying/targets/CMakeLists.txt @@ -11,11 +11,13 @@ set (SOURCE_TARGET_LIST ) set (SOURCE_TARGET_WITHOUT_PLUGIN_LIST - unique_args.cpp ) set (SOURCE_TARGET_CO_LIST + unique_args.cpp counique_args.cpp + dynthreads_unique_args.cpp + nonlinear_communication.cpp ) foreach(source_name ${SOURCE_TARGET_LIST}) diff --git a/verifying/targets/counique_args.cpp b/verifying/targets/counique_args.cpp index b599640e..e890135a 100644 --- a/verifying/targets/counique_args.cpp +++ b/verifying/targets/counique_args.cpp @@ -5,33 +5,23 @@ #include #include "../specs/unique_args.h" +#include "../specs/support_coro.h" -struct Promise; - - -// NOLINTBEGIN(readability-identifier-naming) -struct Coroutine : std::coroutine_handle { - using promise_type = ::Promise; -}; - -struct Promise { - Coroutine get_return_object() { return {Coroutine::from_promise(*this)}; } - std::suspend_never initial_suspend() noexcept { return {}; } - std::suspend_always final_suspend() noexcept { return {}; } - void return_void() {} - void unhandled_exception() {} -}; -// NOLINTEND(readability-identifier-naming) static std::vector used(limit, false); static std::vector done(limit, false); -Coroutine CoFun(int i) { +Coroutine CoWork(int i) { done[i] = true; co_return; } -struct CoUniqueArgsTest { - CoUniqueArgsTest() {} + +Coroutine CoFun(int i) { + co_await CoWork(i); +} + +struct NonLinearCommunicationTest { + NonLinearCommunicationTest() {} ValueWrapper Get(size_t i) { assert(!used[i]); used[i] = true; @@ -60,10 +50,10 @@ auto GenerateArgs(size_t thread_num) { assert(false && "extra call"); } -target_method(GenerateArgs, int, CoUniqueArgsTest, Get, size_t); +target_method(GenerateArgs, int, NonLinearCommunicationTest, Get, size_t); using SpecT = - ltest::Spec; LTEST_ENTRYPOINT(SpecT); diff --git a/verifying/targets/counique_args.yml b/verifying/targets/counique_args.yml index 5cf938e5..75af9bec 100644 --- a/verifying/targets/counique_args.yml +++ b/verifying/targets/counique_args.yml @@ -1,3 +1,5 @@ -- Name: cofun - Coroutine: CoFun(int) +- Action: Coro + Name: cofun + Place: + Function: CoWork(int) \ No newline at end of file diff --git a/verifying/targets/dynthreads_unique_args.cpp b/verifying/targets/dynthreads_unique_args.cpp new file mode 100644 index 00000000..8552df18 --- /dev/null +++ b/verifying/targets/dynthreads_unique_args.cpp @@ -0,0 +1,104 @@ +#include +#include +#include +#include +#include +#include + +#include "../specs/unique_args.h" +#include "runtime/include/lib.h" + +static std::vector used(limit, false); +static std::vector state(limit, 0); +struct Promise; +// NOLINTBEGIN(readability-identifier-naming) +struct SimpleAwaitable { + bool await_ready() const noexcept { return false; } + + bool await_suspend(std::coroutine_handle<> h) const noexcept { + h.resume(); + return true; + } + + void await_resume() const noexcept {} +}; + +struct Coroutine : std::coroutine_handle { + using promise_type = ::Promise; + auto operator co_await() const { return SimpleAwaitable{}; } +}; + +struct Promise { + Coroutine get_return_object() { return {Coroutine::from_promise(*this)}; } + std::suspend_always initial_suspend() noexcept { return {}; } + std::suspend_never final_suspend() noexcept { return {}; } + void return_void() {} + void unhandled_exception() {} +}; +// NOLINTEND(readability-identifier-naming) +struct Waiter { + void Add(Coroutine&& coro) { + // list.push_back(coro); + coro.resume(); + } + SimpleAwaitable Wait() { return {}; } + std::vector list; +}; + +Coroutine DoWork(int i) { + state[i]++; + // std::cerr << "updated" << i << "\n"; + co_return; +} + +Coroutine Work(int i) { + Waiter w; + w.Add(DoWork(i)); + co_await w.Wait(); + assert(state[i] == 1); + co_return; +} + +static std::string str; + +extern "C" char* PrintInt(int i) { + str = std::to_string(i); + return str.data(); +}; +struct DynThreadsTest { + DynThreadsTest() {} + ValueWrapper Get(size_t i) { + assert(!used[i]); + used[i] = true; + bool last = std::count(used.begin(), used.end(), true) == limit; + auto coro = Work(i); + coro.resume(); + auto l = [this]() { + std::fill(used.begin(), used.end(), false); + return limit; + }; + return {last ? l() : std::optional(), + GetDefaultCompator>(), Print}; + } + void Reset() { + std::fill(used.begin(), used.end(), false); + std::fill(state.begin(), state.end(), false); + } +}; + +auto GenerateArgs(size_t thread_num) { + for (size_t i = 0; i < limit; i++) { + if (!used[i]) { + return ltest::generators::makeSingleArg(i); + } + } + assert(false && "extra call"); +} + +target_method(GenerateArgs, int, DynThreadsTest, Get, size_t); + +using SpecT = + ltest::Spec; + +LTEST_ENTRYPOINT(SpecT); diff --git a/verifying/targets/dynthreads_unique_args.yml b/verifying/targets/dynthreads_unique_args.yml new file mode 100644 index 00000000..9a8f6fd5 --- /dev/null +++ b/verifying/targets/dynthreads_unique_args.yml @@ -0,0 +1,14 @@ +- Action: Spawn + CreationId: 0 + HasThis: True + Place: + Function: Waiter::Add(Coroutine&&) +- Action: Wait + WaitsFor: [0] + Place: + Function: Waiter::Wait() +- Action: SpawnedCoro + ArgsFun: PrintInt + Name: DoWork + Place: + Function: DoWork(int) \ No newline at end of file diff --git a/verifying/targets/nonlinear_communication.cpp b/verifying/targets/nonlinear_communication.cpp new file mode 100644 index 00000000..917bf470 --- /dev/null +++ b/verifying/targets/nonlinear_communication.cpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include + +#include "../specs/communication.h" +#include "runtime/include/value_wrapper.h" + +static std::vector used(writer_count, false); + +static int write_id = start_id; +static int read_id = start_id; +struct NonLinearCommunicationTest { + std::deque buf; + NonLinearCommunicationTest() {} + // TODO better support return values + + Coroutine SendHeader(int r) { + buf.push_back(r); + co_return; + } + + Coroutine SendBody(int r) { + buf.push_back(r); + co_return; + } + + Coroutine SendImpl(int i) { + co_await SendHeader(write_id); + co_await SendBody(i); + write_id++; + co_return; + } + void Send(int i) { + assert(!used[i]); + used[i] = true; + SendImpl(i); + } + void Receive() { + for (int i = 0; i < writer_count; i++) { + buf.pop_front(); + int r = buf.front(); + buf.pop_front(); + assert(r < writer_count); + } + } + void Reset() { + std::fill(used.begin(), used.end(), false); + write_id = start_id; + buf.clear(); + } +}; + +auto GenerateArgs(size_t thread_num) { + for (size_t i = 0; i < writer_count; i++) { + if (!used[i]) { + return ltest::generators::makeSingleArg(i); + } + } + assert(false && "extra call"); +} +static constexpr std::string_view send_func = "Send"; +static constexpr std::string_view receive_func = "Receive"; +static constexpr size_t output_thread = writer_count; +class NoReadBeforeWrite { + public: + bool Verify(CreatedTaskMetaData task) { + // output from the pipe is the last thread + if (task.name == send_func && task.thread_id == output_thread) { + return false; + } + if (task.name == receive_func && task.thread_id != output_thread) { + return false; + } + // // no receive before send + if (task.name == receive_func && task.is_new && + write_id - start_id < writer_count) { + return false; + } + return true; + } + void OnFinished(TaskWithMetaData task) {} + void Reset() {} + void UpdateState(std::string_view coro_name, int thread_id, bool) {} +}; + +target_method(GenerateArgs, void, NonLinearCommunicationTest, Send, int); +target_method(ltest::generators::genEmpty, void, NonLinearCommunicationTest, + Receive); + +using SpecT = ltest::Spec; + +LTEST_ENTRYPOINT_CONSTRAINT(SpecT, NoReadBeforeWrite); diff --git a/verifying/targets/nonlinear_communication.yml b/verifying/targets/nonlinear_communication.yml new file mode 100644 index 00000000..00393559 --- /dev/null +++ b/verifying/targets/nonlinear_communication.yml @@ -0,0 +1,8 @@ +- Action: Coro + Name: SendHeader + Place: + Function: NonLinearCommunicationTest::SendHeader(int) +- Action: Coro + Name: SendBody + Place: + Function: NonLinearCommunicationTest::SendBody(int) \ No newline at end of file diff --git a/verifying/targets/unique_args.cpp b/verifying/targets/unique_args.cpp index ee53dc8d..b9c40f76 100644 --- a/verifying/targets/unique_args.cpp +++ b/verifying/targets/unique_args.cpp @@ -8,17 +8,19 @@ static std::vector used(limit, false); static std::vector done(limit, false); -struct CoUniqueArgsTest { - CoUniqueArgsTest() {} +void DoWork(int i){ + done[i] = true; +} +struct DynThreadsTest { + DynThreadsTest() {} ValueWrapper Get(size_t i) { assert(!used[i]); used[i] = true; - CoroYield(); + DoWork(i); auto l = [this]() { Reset(); return limit; }; - done[i] = true; return {std::count(done.begin(), done.end(), false) == 0 ? l() : std::optional(), @@ -39,10 +41,10 @@ auto GenerateArgs(size_t thread_num) { assert(false && "extra call"); } -target_method(GenerateArgs, int, CoUniqueArgsTest, Get, size_t); +target_method(GenerateArgs, int, DynThreadsTest, Get, size_t); using SpecT = - ltest::Spec; LTEST_ENTRYPOINT(SpecT); diff --git a/verifying/targets/unique_args.yml b/verifying/targets/unique_args.yml new file mode 100644 index 00000000..39f78f6b --- /dev/null +++ b/verifying/targets/unique_args.yml @@ -0,0 +1,4 @@ +- Action: Generic + Name: work + Place: + Function: DoWork(int) \ No newline at end of file