From f1f27e8d60ba89b9a0b76b0f566b3aefb84adb19 Mon Sep 17 00:00:00 2001 From: Chris Feger Date: Mon, 11 Mar 2024 17:50:47 -0400 Subject: [PATCH 1/3] Add AsyncOsi to client --- .../Extender/Client/ClientNetworking.cpp | 75 +++++ .../Extender/Server/ServerNetworking.cpp | 170 ++++++++++ .../Extender/Server/ServerNetworking.h | 4 +- BG3Extender/Extender/Shared/ExtenderNet.h | 3 +- .../Extender/Shared/ExtenderProtocol.proto | 41 +++ .../Extender/Shared/ThreadedExtenderState.inl | 2 +- BG3Extender/Extender/Shared/Utils.h | 6 +- BG3Extender/Lua/Client/LuaBindingClient.h | 9 + BG3Extender/Lua/Client/LuaClient.cpp | 4 + BG3Extender/Lua/Client/LuaOsirisBinding.h | 38 +++ BG3Extender/Lua/Helpers/LuaPush.h | 5 + BG3Extender/Lua/LuaBinding.cpp | 34 ++ BG3Extender/Lua/LuaBinding.h | 26 ++ BG3Extender/Lua/LuaOsiBridge.cpp | 302 +++++++++++++++--- BG3Extender/Lua/Server/LuaOsirisBinding.cpp | 4 +- BG3Extender/Lua/Server/LuaOsirisBinding.h | 1 - .../Lua/Shared/Proxies/LuaArrayProxy.h | 3 + BG3Extender/Osiris/OsirisExtender.cpp | 34 ++ BG3Extender/Osiris/OsirisExtender.h | 7 + 19 files changed, 713 insertions(+), 55 deletions(-) create mode 100644 BG3Extender/Lua/Client/LuaOsirisBinding.h diff --git a/BG3Extender/Extender/Client/ClientNetworking.cpp b/BG3Extender/Extender/Client/ClientNetworking.cpp index b4aef8e9d..f4098429e 100644 --- a/BG3Extender/Extender/Client/ClientNetworking.cpp +++ b/BG3Extender/Extender/Client/ClientNetworking.cpp @@ -58,6 +58,81 @@ void ExtenderProtocol::ProcessExtenderMessage(net::MessageContext& context, net: break; } + case net::MessageWrapper::kS2COsirisQueryResponse: + { + ecl::LuaClientPin pin(ecl::ExtensionState::Get()); + if (pin) + { + const auto& response = msg.s2c_osiris_query_response(); + + if (response.response_case() == net::MsgS2COsirisQueryResponse::ResponseCase::kError) + { + pin->ResolveOsirisFuture(response.responseid(), response.error()); + } + else if (!response.succeeded()) + { + pin->ResolveOsirisFuture(response.responseid(), "Unknown error occurred in query"); + } + else + { + Array>> respdata; + + respdata.Reallocate(response.results_size()); + + bool resultError = false; + + for (const auto& result : response.results()) + { + respdata.Add({}); + auto& currResp = *(respdata.begin() + respdata.size() - 1); + currResp.Reallocate(result.num_retvals()); + for (const auto& val : result.retvals()) + { + // Note: for some reason assigning to/emplacing to a variant crashes the game. + // So... don't + while (val.index() > currResp.size()) + { + currResp.Add({}); + } + switch (val.val_case()) + { + case std::remove_cvref_t::ValCase::kIntv: + currResp.Add({val.intv()}); + break; + case std::remove_cvref_t::ValCase::kNumv: + currResp.Add({val.numv()}); + break; + case std::remove_cvref_t::ValCase::kStrv: + currResp.Add({val.strv()}); + break; + } + } + + while (currResp.size() < result.num_retvals()) + { + currResp.Add({}); + } + + if (currResp.size() > result.num_retvals()) + { + resultError = true; + break; + } + } + + if (resultError) + { + pin->ResolveOsirisFuture(response.responseid(), "Malformed response: values not in order"); + } + else + { + pin->ResolveOsirisFuture(response.responseid(), respdata); + } + } + } + break; + } + default: OsiErrorS("Unknown extension message type received!"); } diff --git a/BG3Extender/Extender/Server/ServerNetworking.cpp b/BG3Extender/Extender/Server/ServerNetworking.cpp index 0162cd0a3..6bae18b24 100644 --- a/BG3Extender/Extender/Server/ServerNetworking.cpp +++ b/BG3Extender/Extender/Server/ServerNetworking.cpp @@ -4,6 +4,25 @@ BEGIN_NS(esv) +namespace +{ + void OsirisQueryErrorResponse(net::MessageContext& context, const net::MessageWrapper& msg, const char* error, ...) + { + auto message = gExtender->GetServer().GetNetworkManager().GetFreeMessage(); + auto response = message->GetMessage().mutable_s2c_osiris_query_response(); + static char errorformatbuf[1024]; + va_list args; + va_start(args, error); + vsnprintf(errorformatbuf, 1024, error, args); + va_end(args); + response->set_error(errorformatbuf); + response->set_responseid(msg.c2s_osiris_query().msgid()); + gExtender->GetServer().GetNetworkManager().Send(message, context.UserID); + } +} + +#define QUERY_ERROR(...) OsirisQueryErrorResponse(context, msg, __VA_ARGS__) + net::ProtocolResult ExtenderProtocol::ProcessMsg(void* unused, net::MessageContext* context, net::Message* msg) { auto base = ExtenderProtocolBase::ProcessMsg(unused, context, msg); @@ -41,11 +60,162 @@ void ExtenderProtocol::ProcessExtenderMessage(net::MessageContext& context, net: break; } + case net::MessageWrapper::kC2SOsirisQuery: + { + // FIXME - not yet supported + auto const& query = msg.c2s_osiris_query(); + + if (!gExtender->GetCurrentExtensionState()->GetLua() || gExtender->GetCurrentExtensionState()->GetLua()->RestrictionFlags & lua::State::RestrictOsiris) { + QUERY_ERROR("Attempted to read Osiris database in restricted context"); + } + + switch (query.type()) + { + case net::OsirisQueryType::OSIRIS_CALL: + QUERY_ERROR("Remote Osiris calls not (yet) supported!"); + OsiErrorS("Remote Osiris calls not (yet) supported!"); + break; + case net::OsirisQueryType::OSIRIS_DEFER: + QUERY_ERROR("Remote Osiris deferrals not (yet) supported!"); + OsiErrorS("Remote Osiris deferrals not (yet) supported!"); + break; + case net::OsirisQueryType::OSIRIS_DELETE: + QUERY_ERROR("Remote Osiris deletions not (yet) supported!"); + OsiErrorS("Remote Osiris deletions not (yet) supported!"); + break; + case net::OsirisQueryType::OSIRIS_GET: + ProcessOsirisGet(context, msg); + break; + } + break; + } + default: OsiErrorS("Unknown extension message type received!"); } } +void ExtenderProtocol::ProcessOsirisGet(net::MessageContext& context, net::MessageWrapper& msg) +{ + auto const& query = msg.c2s_osiris_query(); + + auto func = gExtender->GetServer().Osiris().LookupFunction(query.name().c_str(), query.num_args()); + if (func != nullptr && func->Signature->OutParamList.numOutParams() == 0) { + if (func->Type == FunctionType::Database) + { + auto db = func->Node.Get()->Database.Get(); + + auto head = db->Facts.Head; + auto current = head->Next; + + auto message = gExtender->GetServer().GetNetworkManager().GetFreeMessage(); + auto response = message->GetMessage().mutable_s2c_osiris_query_response(); + response->set_responseid(query.msgid()); + response->set_succeeded(true); + + std::vector*> args(query.num_args(), nullptr); + for (int i = 0; i < query.args_size(); i++) + { + args[query.args(i).index()] = &query.args(i); + } + + const auto check_match = [&args](TupleVec const& v){ + for (int i = 0; i < v.Size; i++) + { + switch (gExtender->GetServer().Osiris().GetBaseType((ValueType)v.Values[i].TypeId)) + { + case ValueType::Integer: + if (args[i] != nullptr && + !(args[i]->val_case() == args[i]->kIntv && args[i]->intv() == v.Values[i].Value.Int32)) + { + return false; + } + break; + + case ValueType::Integer64: + if (args[i] != nullptr && + !(args[i]->val_case() == args[i]->kIntv && args[i]->intv() == v.Values[i].Value.Int64)) + { + return false; + } + break; + + case ValueType::Real: + if (args[i] != nullptr && + !(args[i]->val_case() == args[i]->kNumv && abs(v.Values[i].Value.Float - args[i]->numv()) <= 0.00001f)) + { + return false; + } + break; + + case ValueType::String: + case ValueType::GuidString: + if (args[i] != nullptr && + !(args[i]->val_case() == args[i]->kStrv && args[i]->strv() == v.Values[i].Value.String)) + { + return false; + } + break; + } + } + return true; + }; + + while (current != head) { + if (check_match(current->Item)) { + auto result = response->add_results(); + result->set_num_retvals((uint32_t)args.size()); + for (int i = 0; i < args.size(); i++) + { + if (!args[i]) + { + auto v = result->add_retvals(); + v->set_index(i); + switch (gExtender->GetServer().Osiris().GetBaseType((ValueType)current->Item.Values[i].TypeId)) + { + case ValueType::Integer: + v->set_intv(current->Item.Values[i].Value.Int32); + break; + + case ValueType::Integer64: + v->set_intv(current->Item.Values[i].Value.Int64); + break; + + case ValueType::Real: + v->set_numv(current->Item.Values[i].Value.Float); + break; + + case ValueType::String: + case ValueType::GuidString: + v->set_strv(current->Item.Values[i].Value.String); + break; + default: + OsiErrorS("Unhandled Osi TypedValue"); + message->GetMessage().mutable_s2c_osiris_query_response()->set_succeeded(false); + message->GetMessage().mutable_s2c_osiris_query_response()->set_error("Unhandled Osi TypedValue"); + goto sendtime; // Used as a multi-stage break; + } + } + } + } + + current = current->Next; + } + + sendtime: + gExtender->GetServer().GetNetworkManager().Send(message, context.UserID); + } + else + { + QUERY_ERROR("Function '%s(%d)' is not a database", query.name().c_str(), query.args()); + } + } + else + { + QUERY_ERROR("No database named '%s(%d)' exists", query.name().c_str(), query.args()); + } +} + void NetworkManager::Reset() { peerVersions_.clear(); diff --git a/BG3Extender/Extender/Server/ServerNetworking.h b/BG3Extender/Extender/Server/ServerNetworking.h index 44959b144..b5d27476c 100644 --- a/BG3Extender/Extender/Server/ServerNetworking.h +++ b/BG3Extender/Extender/Server/ServerNetworking.h @@ -13,6 +13,8 @@ class ExtenderProtocol : public net::ExtenderProtocolBase protected: void ProcessExtenderMessage(net::MessageContext& context, net::MessageWrapper& msg) override; + + void ProcessOsirisGet(net::MessageContext& context, net::MessageWrapper& msg); }; class NetworkManager @@ -41,4 +43,4 @@ class NetworkManager std::unordered_map peerVersions_; }; -END_NS() \ No newline at end of file +END_NS() diff --git a/BG3Extender/Extender/Shared/ExtenderNet.h b/BG3Extender/Extender/Shared/ExtenderNet.h index fe178e4dd..3685a974e 100644 --- a/BG3Extender/Extender/Shared/ExtenderNet.h +++ b/BG3Extender/Extender/Shared/ExtenderNet.h @@ -12,8 +12,9 @@ class ExtenderMessage : public Message static constexpr uint32_t MaxPayloadLength = 0xfffff; static constexpr uint32_t VerInitial = 1; + static constexpr uint32_t VerClientOsirisQuery = VerInitial + 1; // Version of protocol, increment each time the protobuf changes - static constexpr uint32_t ProtoVersion = VerInitial; + static constexpr uint32_t ProtoVersion = VerClientOsirisQuery; ExtenderMessage(); ~ExtenderMessage() override; diff --git a/BG3Extender/Extender/Shared/ExtenderProtocol.proto b/BG3Extender/Extender/Shared/ExtenderProtocol.proto index e4e3cc017..6654ee49a 100644 --- a/BG3Extender/Extender/Shared/ExtenderProtocol.proto +++ b/BG3Extender/Extender/Shared/ExtenderProtocol.proto @@ -92,6 +92,45 @@ message MsgUserVars { repeated UserVar vars = 1; } +message OsirisVal { + uint32 index = 1; + oneof val { + string strv = 2; + int64 intv = 3; + float numv = 4; + } +} + +enum OsirisQueryType { + OSIRIS_GET = 0; + OSIRIS_CALL = 1; + OSIRIS_DELETE = 2; + OSIRIS_DEFER = 3; +} + +// Notifies the server that the client requested an Osiris query +message MsgC2SOsirisQuery { + uint32 msgid = 1; + string name = 2; + OsirisQueryType type = 3; + uint32 num_args = 4; // Full number of arguments, including null ones + repeated OsirisVal args = 5; +} + +message OsirisResult { + uint32 num_retvals = 1; + repeated OsirisVal retvals = 2; +} + +message MsgS2COsirisQueryResponse { + uint32 responseid = 1; + repeated OsirisResult results = 2; + oneof response { + bool succeeded = 3; + string error = 4; + } +} + message MessageWrapper { oneof msg { MsgPostLuaMessage post_lua = 1; @@ -100,5 +139,7 @@ message MessageWrapper { MsgS2CSyncStat s2c_sync_stat = 6; MsgS2CKick s2c_kick = 7; MsgUserVars user_vars = 8; + MsgC2SOsirisQuery c2s_osiris_query = 9; + MsgS2COsirisQueryResponse s2c_osiris_query_response = 10; } } diff --git a/BG3Extender/Extender/Shared/ThreadedExtenderState.inl b/BG3Extender/Extender/Shared/ThreadedExtenderState.inl index 6d30aa076..1706e412d 100644 --- a/BG3Extender/Extender/Shared/ThreadedExtenderState.inl +++ b/BG3Extender/Extender/Shared/ThreadedExtenderState.inl @@ -20,7 +20,7 @@ void ThreadedExtenderState::RemoveThread(DWORD threadId) void ThreadedExtenderState::EnqueueTask(std::function fun) { - threadTasks_.push(fun); + threadTasks_.push(std::move(fun)); } void ThreadedExtenderState::SubmitTaskAndWait(std::function fun) diff --git a/BG3Extender/Extender/Shared/Utils.h b/BG3Extender/Extender/Shared/Utils.h index 2b99e8349..957bd69fd 100644 --- a/BG3Extender/Extender/Shared/Utils.h +++ b/BG3Extender/Extender/Shared/Utils.h @@ -13,19 +13,19 @@ BEGIN_SE() #if !defined(OSI_NO_DEBUG_LOG) #define LuaError(msg) { \ std::stringstream ss; \ - ss << __FUNCTION__ "(): " msg; \ + ss << __FUNCTION__ "(): " << msg; \ LogLuaError(ss.str()); \ } #define OsiError(msg) { \ std::stringstream ss; \ - ss << __FUNCTION__ "(): " msg; \ + ss << __FUNCTION__ "(): " << msg; \ LogOsirisError(ss.str()); \ } #define OsiWarn(msg) { \ std::stringstream ss; \ - ss << __FUNCTION__ "(): " msg; \ + ss << __FUNCTION__ "(): " << msg; \ LogOsirisWarning(ss.str()); \ } diff --git a/BG3Extender/Lua/Client/LuaBindingClient.h b/BG3Extender/Lua/Client/LuaBindingClient.h index cccdb0117..4ad560aa0 100644 --- a/BG3Extender/Lua/Client/LuaBindingClient.h +++ b/BG3Extender/Lua/Client/LuaBindingClient.h @@ -4,6 +4,7 @@ #if defined(ENABLE_UI) #include #endif +#include namespace bg3se::ecl::lua { @@ -49,6 +50,14 @@ namespace bg3se::ecl::lua public: void Register(lua_State * L) override; void RegisterLib(lua_State * L) override; + + private: + static char const * const NameResolverMetatableName; + + void RegisterNameResolverMetatable(lua_State * L); + void CreateNameResolver(lua_State * L); + + static int LuaIndexResolverTable(lua_State* L); }; diff --git a/BG3Extender/Lua/Client/LuaClient.cpp b/BG3Extender/Lua/Client/LuaClient.cpp index d1b7e2820..268fefa4e 100644 --- a/BG3Extender/Lua/Client/LuaClient.cpp +++ b/BG3Extender/Lua/Client/LuaClient.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "resource.h" #if defined(ENABLE_UI) #include @@ -33,6 +34,9 @@ LifetimePool& GetClientLifetimePool() void ExtensionLibraryClient::Register(lua_State * L) { ExtensionLibrary::Register(L); + AsyncOsiFunctionNameProxy::RegisterMetatable(L); + RegisterNameResolverMetatable(L); + CreateNameResolver(L); } void ExtensionLibraryClient::RegisterLib(lua_State * L) diff --git a/BG3Extender/Lua/Client/LuaOsirisBinding.h b/BG3Extender/Lua/Client/LuaOsirisBinding.h new file mode 100644 index 000000000..fc3683591 --- /dev/null +++ b/BG3Extender/Lua/Client/LuaOsirisBinding.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include +#include + +BEGIN_NS(ecl::lua) + +using namespace bg3se::lua; + +class AsyncOsiFunctionNameProxy : public Userdata, public Callable +{ +public: + static char const * const MetatableName; + // Maximum number of OUT params that a query can return. + // (This setting determines how many function arities we'll check during name lookup) + static constexpr uint32_t MaxQueryOutParams = 6; + + static void PopulateMetatable(lua_State * L); + + AsyncOsiFunctionNameProxy(STDString const & name, class ClientState & state); + + int LuaCall(lua_State * L); + +private: + static std::atomic currentId; + STDString name_; + class ClientState & state_; + + static int LuaGet(lua_State * L); + static int LuaDelete(lua_State * L); + static int LuaDeferredNotification(lua_State * L); + bool BeforeCall(lua_State * L); +}; + +END_NS() diff --git a/BG3Extender/Lua/Helpers/LuaPush.h b/BG3Extender/Lua/Helpers/LuaPush.h index f7cdb7f8d..1949343ce 100644 --- a/BG3Extender/Lua/Helpers/LuaPush.h +++ b/BG3Extender/Lua/Helpers/LuaPush.h @@ -7,6 +7,11 @@ inline void push(lua_State* L, nullptr_t v) lua_pushnil(L); } +inline void push(lua_State* L, std::monostate v) +{ + lua_pushnil(L); +} + inline void push(lua_State* L, bool v) { lua_pushboolean(L, v ? 1 : 0); diff --git a/BG3Extender/Lua/LuaBinding.cpp b/BG3Extender/Lua/LuaBinding.cpp index 97d232a15..1547a8df1 100644 --- a/BG3Extender/Lua/LuaBinding.cpp +++ b/BG3Extender/Lua/LuaBinding.cpp @@ -177,6 +177,7 @@ namespace bg3se::lua void ExtensionLibrary::Register(lua_State * L) { RegisterLib(L); + AsyncOsiFuture::RegisterMetatable(L); } @@ -467,6 +468,39 @@ namespace bg3se::lua ThrowEvent("NetMessage", params); } + AsyncOsiFuture* State::CreateOsirisFuture(uint32_t id) + { + auto ret = Userdata::New(L); + futureManager_.emplace(id, RegistryEntry(L, -1)); + return ret; + } + + void State::ResolveOsirisFuture(uint32_t id, std::string_view err) + { + auto found = futureManager_.find(id); + if (found != futureManager_.end()) + { + found->second.Push(); + auto future = AsyncOsiFuture::CheckUserData(L, -1); + future->Resolve(err); + + futureManager_.erase(found); + } + } + + void State::ResolveOsirisFuture(uint32_t id, Array>>& results) + { + auto found = futureManager_.find(id); + if (found != futureManager_.end()) + { + found->second.Push(); + auto future = AsyncOsiFuture::CheckUserData(L, -1); + future->Resolve(results); + + futureManager_.erase(found); + } + } + STDString State::GetBuiltinLibrary(int resourceId) { auto resource = GetExeResource(resourceId); diff --git a/BG3Extender/Lua/LuaBinding.h b/BG3Extender/Lua/LuaBinding.h index 11d617fd0..3a1c979e2 100644 --- a/BG3Extender/Lua/LuaBinding.h +++ b/BG3Extender/Lua/LuaBinding.h @@ -179,6 +179,9 @@ namespace bg3se::lua virtual void OnUpdate(GameTime const& time); void OnStatsStructureLoaded(); void OnNetMessageReceived(STDString const& channel, STDString const& payload, UserId userId); + class AsyncOsiFuture* CreateOsirisFuture(uint32_t id); + void ResolveOsirisFuture(uint32_t id, std::string_view err); + void ResolveOsirisFuture(uint32_t id, Array>>& results); template bool CallExtRet(char const * func, uint32_t restrictions, std::tuple& ret, Args... args) @@ -236,6 +239,7 @@ namespace bg3se::lua CachedModVariableManager modVariableManager_; EntityComponentEventHooks entityHooks_; timer::TimerSystem timers_; + std::unordered_map futureManager_; void OpenLibs(); EventResult DispatchEvent(EventBase& evt, char const* eventName, bool canPreventAction, uint32_t restrictions); @@ -281,6 +285,28 @@ namespace bg3se::lua STDString Payload; UserId UserID; }; + + class AsyncOsiFuture : public Userdata + { + public: + static char const * const MetatableName; + + static void PopulateMetatable(lua_State * L); + + static AsyncOsiFuture* New(lua_State* L, uint32_t id) + { + return State::FromLua(L)->CreateOsirisFuture(id); + } + + void Resolve(std::string_view err); + void Resolve(Array>>& results); + private: + std::vector ThenList; + std::optional Else; + + static int LuaThen(lua_State * L); + static int LuaElse(lua_State * L); + }; } namespace bg3se::lua::stats diff --git a/BG3Extender/Lua/LuaOsiBridge.cpp b/BG3Extender/Lua/LuaOsiBridge.cpp index cc3d42b89..8aeab3b2d 100644 --- a/BG3Extender/Lua/LuaOsiBridge.cpp +++ b/BG3Extender/Lua/LuaOsiBridge.cpp @@ -37,15 +37,81 @@ Node* NodeRef::Get() const END_SE() -namespace bg3se::esv::lua +namespace bg3se::lua { - using namespace bg3se::lua; + char const * const AsyncOsiFuture::MetatableName = "AsyncOsiFuture"; - ValueType GetBaseType(ValueType type) + void AsyncOsiFuture::PopulateMetatable(lua_State * L) { - return (*gExtender->GetServer().Osiris().GetGlobals().Types)->ResolveAlias((uint16_t)type); + lua_newtable(L); + + lua_pushcfunction(L, &LuaThen); + lua_setfield(L, -2, "Then"); + + lua_pushcfunction(L, &LuaElse); + lua_setfield(L, -2, "Else"); + + lua_setfield(L, -2, "__index"); } + void AsyncOsiFuture::Resolve(std::string_view err) + { + ecl::LuaClientPin pin(ecl::ExtensionState::Get()); + if (pin) + { + auto L = pin->GetState(); + if (Else) + { + Else->Push(); + push(L, err); + CheckedCall(L, 1, "AsyncOsiFutureErrorCallback"); + } + } + } + + void AsyncOsiFuture::Resolve(Array>>& results) + { + ecl::LuaClientPin pin(ecl::ExtensionState::Get()); + if (pin) + { + auto L = pin->GetState(); + LifetimeStackPin lifetime(pin->GetStack()); + push(L, results, lifetime.GetLifetime()); + for (const auto& then : ThenList) + { + then.Push(); + lua_pushvalue(L, -2); + CheckedCall(L, 1, "AsyncOsiFutureSuccessCallback"); + } + } + } + + int AsyncOsiFuture::LuaThen(lua_State * L) + { + auto self = AsyncOsiFuture::CheckUserData(L, 1); + luaL_checktype(L, 2, LUA_TFUNCTION); + + self->ThenList.emplace_back(L, 2); + + lua_pushvalue(L, 1); + return 1; + } + + int AsyncOsiFuture::LuaElse(lua_State * L) + { + auto self = AsyncOsiFuture::CheckUserData(L, 1); + luaL_checktype(L, 2, LUA_TFUNCTION); + + self->Else.emplace(L, 2); + + return 0; + } +} + +namespace bg3se::esv::lua +{ + using namespace bg3se::lua; + int64_t LuaToInt(lua_State* L, int i, int type) { if (type == LUA_TNUMBER) { @@ -98,7 +164,7 @@ namespace bg3se::esv::lua return; } - switch (GetBaseType(osiType)) { + switch (gExtender->GetServer().Osiris().GetBaseType(osiType)) { case ValueType::Integer: tv.Value.Int32 = (int32_t)LuaToInt(L, i, type); break; @@ -150,7 +216,7 @@ namespace bg3se::esv::lua return; } - switch (GetBaseType(osiType)) { + switch (gExtender->GetServer().Osiris().GetBaseType(osiType)) { case ValueType::Integer: arg.Int32 = (int32_t)LuaToInt(L, i, type); break; @@ -192,7 +258,7 @@ namespace bg3se::esv::lua void OsiToLua(lua_State * L, OsiArgumentValue const & arg) { - switch (GetBaseType(arg.TypeId)) { + switch (gExtender->GetServer().Osiris().GetBaseType(arg.TypeId)) { case ValueType::None: lua_pushnil(L); break; @@ -222,7 +288,7 @@ namespace bg3se::esv::lua void OsiToLua(lua_State * L, TypedValue const & tv) { - switch (GetBaseType((ValueType)tv.TypeId)) { + switch (gExtender->GetServer().Osiris().GetBaseType((ValueType)tv.TypeId)) { case ValueType::None: lua_pushnil(L); break; @@ -250,41 +316,6 @@ namespace bg3se::esv::lua } } - uint32_t FunctionNameHash(char const * str) - { - uint32_t hash{ 0 }; - while (*str) { - hash = (*str++ | 0x20) + 129 * (hash % 4294967); - } - - return hash; - } - - Function const* LookupOsiFunction(STDString const& name, uint32_t arity) - { - auto functions = gExtender->GetServer().Osiris().GetGlobals().Functions; - if (!functions) { - return nullptr; - } - - OsiString sig(name); - sig += "/"; - sig += std::to_string(arity); - - auto hash = FunctionNameHash(name.c_str()) + arity; - auto func = (*functions)->Find(hash, sig); - if (func == nullptr - || ((*func)->Node.Id == 0 - && (*func)->Type != FunctionType::Call - && (*func)->Type != FunctionType::Query - && (*func)->Type != FunctionType::Event)) { - return nullptr; - } - - return *func; - }; - - bool OsiFunction::Bind(Function const * func, ServerState & state) { if (func->Type == FunctionType::Query @@ -463,7 +494,7 @@ namespace bg3se::esv::lua for (auto i = 0; i < tuple.Size; i++) { if (!lua_isnil(L, firstIndex + i)) { auto const & v = tuple.Values[i]; - switch (GetBaseType((ValueType)v.TypeId)) { + switch (gExtender->GetServer().Osiris().GetBaseType((ValueType)v.TypeId)) { case ValueType::Integer: if (v.Value.Int32 != lua_tointeger(L, firstIndex + i)) { return false; @@ -869,14 +900,14 @@ namespace bg3se::esv::lua } // Look for Call/Proc/Event/Query (number of OUT args == 0) - auto func = LookupOsiFunction(name_, arity); + auto func = gExtender->GetServer().Osiris().LookupFunction(name_, arity); if (func != nullptr && func->Signature->OutParamList.numOutParams() == 0) { return CreateFunctionMapping(arity, func); } for (uint32_t args = arity + 1; args < arity + MaxQueryOutParams; args++) { // Look for Query/UserQuery (number of OUT args > 0) - auto func = LookupOsiFunction(name_, args); + auto func = gExtender->GetServer().Osiris().LookupFunction(name_, args); if (func != nullptr) { auto outParams = func->Signature->OutParamList.numOutParams(); auto params = func->Signature->Params->Params.Size - outParams; @@ -1140,3 +1171,182 @@ namespace bg3se::esv::lua return STDString(ss.str()); } } + +namespace bg3se::ecl::lua +{ + using namespace bg3se::lua; + + char const * const ExtensionLibraryClient::NameResolverMetatableName = "AsyncOsiProxyNameResolver"; + + + void ExtensionLibraryClient::RegisterNameResolverMetatable(lua_State * L) + { + lua_register(L, NameResolverMetatableName, nullptr); + luaL_newmetatable(L, NameResolverMetatableName); // stack: mt + lua_pushcfunction(L, &LuaIndexResolverTable); // stack: mt, &LuaIndexResolverTable + lua_setfield(L, -2, "__index"); // mt.__index = &LuaIndexResolverTable; stack: mt + lua_pop(L, 1); // stack: mt + } + + void ExtensionLibraryClient::CreateNameResolver(lua_State * L) + { + lua_newtable(L); // stack: osi + luaL_setmetatable(L, NameResolverMetatableName); // stack: osi + lua_setglobal(L, "AsyncOsi"); // stack: - + } + + int ExtensionLibraryClient::LuaIndexResolverTable(lua_State * L) + { + luaL_checktype(L, 1, LUA_TTABLE); + auto name = luaL_checkstring(L, 2); + + LuaClientPin lua(ExtensionState::Get()); + AsyncOsiFunctionNameProxy::New(L, name, std::ref(lua.Get())); // stack: tab, name, proxy + + lua_pushvalue(L, 1); // stack: fun, tab + push(L, name); // stack: fun, tab, name + lua_pushvalue(L, -3); // stack: fun, tab, name, fun + lua_rawset(L, -3); // stack: fun + lua_pop(L, 1); + return 1; + } + + char const * const AsyncOsiFunctionNameProxy::MetatableName = "AsyncOsiFunctionNameProxy"; + std::atomic AsyncOsiFunctionNameProxy::currentId = 0; + + void AsyncOsiFunctionNameProxy::PopulateMetatable(lua_State * L) + { + lua_newtable(L); + + lua_pushcfunction(L, &LuaGet); + lua_setfield(L, -2, "Get"); + + lua_pushcfunction(L, &LuaDelete); + lua_setfield(L, -2, "Delete"); + + lua_pushcfunction(L, &LuaDeferredNotification); + lua_setfield(L, -2, "Defer"); + + lua_setfield(L, -2, "__index"); + } + + AsyncOsiFunctionNameProxy::AsyncOsiFunctionNameProxy(STDString const & name, ClientState & state) + : name_(name), state_(state) + {} + + bool AsyncOsiFunctionNameProxy::BeforeCall(lua_State * L) + { + if (state_.RestrictionFlags & State::RestrictOsiris) { + luaL_error(L, "Attempted to access Osiris function in restricted context"); + return false; + } + + return true; + } + + int AsyncOsiFunctionNameProxy::LuaCall(lua_State * L) + { + if (!BeforeCall(L)) return 1; + + // Note: might be a good idea to do this in the future. Protocol supports it + luaL_error(L, "Cannot perform queries or procedures or add to databases from the client"); + return 1; + } + + int AsyncOsiFunctionNameProxy::LuaGet(lua_State * L) + { + auto self = AsyncOsiFunctionNameProxy::CheckUserData(L, 1); + if (!self->BeforeCall(L)) return 1; + + const int args = lua_gettop(L) - 1; + + // Check that all the args are fine: + for (int i = 0; i < args; i++) + { + const int idx = i+2; + switch (lua_type(L, idx)) + { + case LUA_TNIL: // Null values don't get sent, but are fine + case LUA_TLIGHTUSERDATA: + case LUA_TNUMBER: + case LUA_TSTRING: + break; + case LUA_TBOOLEAN: + luaL_error(L, "Cannot use booleans in Osiris"); + return 1; + default: + luaL_error(L, "Improper type passed to an Osiris database as argument %d: %s", idx, luaL_typename(L, idx)); + return 1; + } + } + + const uint32_t msgid = currentId++; + + auto message = gExtender->GetClient().GetNetworkManager().GetFreeMessage(); + + auto query = message->GetMessage().mutable_c2s_osiris_query(); + + query->set_type(net::OsirisQueryType::OSIRIS_GET); + + query->set_name(self->name_.c_str()); + + query->set_num_args(args); + + query->set_msgid(msgid); + + // Actually use them: + for (int i = 0; i < args; i++) + { + const int idx = i+2; + switch (lua_type(L, idx)) + { + case LUA_TNIL: + break; + case LUA_TLIGHTUSERDATA: + { + auto handle = get(L, idx); + query->add_args()->set_intv((int64_t)handle.Handle); + break; + } + case LUA_TNUMBER: + if (lua_isinteger(L, idx)) + { + query->add_args()->set_intv(lua_tointeger(L, idx)); + } + else + { + query->add_args()->set_numv((float)lua_tonumber(L, idx)); + } + break; + case LUA_TSTRING: + query->add_args()->set_strv(lua_tostring(L, idx)); + break; + } + } + + gExtender->GetClient().GetNetworkManager().Send(message); + + AsyncOsiFuture::New(L, msgid); + return 1; + } + + int AsyncOsiFunctionNameProxy::LuaDelete(lua_State * L) + { + auto self = AsyncOsiFunctionNameProxy::CheckUserData(L, 1); + if (!self->BeforeCall(L)) return 1; + + // Note: might be a good idea to do this in the future. Protocol supports it + luaL_error(L, "Cannot delete from databases from the client"); + return 1; + } + + int AsyncOsiFunctionNameProxy::LuaDeferredNotification(lua_State * L) + { + auto self = AsyncOsiFunctionNameProxy::CheckUserData(L, 1); + if (!self->BeforeCall(L)) return 1; + + // Note: might be a good idea to do this in the future. Protocol (might) support it + luaL_error(L, "I'm not even sure what this is supposed to do but it probably isn't doable from client-side, and also isn't implemented server-side"); + return 1; + } +} diff --git a/BG3Extender/Lua/Server/LuaOsirisBinding.cpp b/BG3Extender/Lua/Server/LuaOsirisBinding.cpp index 903376008..9a9fcfd7b 100644 --- a/BG3Extender/Lua/Server/LuaOsirisBinding.cpp +++ b/BG3Extender/Lua/Server/LuaOsirisBinding.cpp @@ -252,10 +252,10 @@ void OsirisCallbackManager::StorySetMerging(bool isMerging) void OsirisCallbackManager::RegisterNodeHandler(OsirisHookSignature const& sig, SubscriptionId handlerId) { - auto func = LookupOsiFunction(sig.name, sig.arity); + auto func = gExtender->GetServer().Osiris().LookupFunction(sig.name, sig.arity); if (func != nullptr && func->Type == FunctionType::UserQuery) { // We need to find the backing node for the user query - func = LookupOsiFunction(sig.name + "__DEF__", sig.arity); + func = gExtender->GetServer().Osiris().LookupFunction(sig.name + "__DEF__", sig.arity); } if (func == nullptr) { diff --git a/BG3Extender/Lua/Server/LuaOsirisBinding.h b/BG3Extender/Lua/Server/LuaOsirisBinding.h index 57a5eb656..ced15bb3d 100644 --- a/BG3Extender/Lua/Server/LuaOsirisBinding.h +++ b/BG3Extender/Lua/Server/LuaOsirisBinding.h @@ -57,7 +57,6 @@ TypedValue * LuaToOsi(lua_State * L, int i, ValueType osiType, bool allowNil = f void LuaToOsi(lua_State * L, int i, OsiArgumentValue & arg, ValueType osiType, bool allowNil = false, bool reuseStrings = false); void OsiToLua(lua_State * L, OsiArgumentValue const & arg); void OsiToLua(lua_State * L, TypedValue const & tv); -Function const* LookupOsiFunction(STDString const& name, uint32_t arity); class OsiFunction { diff --git a/BG3Extender/Lua/Shared/Proxies/LuaArrayProxy.h b/BG3Extender/Lua/Shared/Proxies/LuaArrayProxy.h index 0407293d7..4d1a3768e 100644 --- a/BG3Extender/Lua/Shared/Proxies/LuaArrayProxy.h +++ b/BG3Extender/Lua/Shared/Proxies/LuaArrayProxy.h @@ -18,6 +18,9 @@ BY_VAL(ecs::EntityRef); BY_VAL(FixedString); BY_VAL(STDString); BY_VAL(STDWString); +BY_VAL(std::monostate); +BY_VAL(StringView); +BY_VAL(WStringView); #if defined(ENABLE_UI) BY_VAL(Noesis::String); BY_VAL(Noesis::Symbol); diff --git a/BG3Extender/Osiris/OsirisExtender.cpp b/BG3Extender/Osiris/OsirisExtender.cpp index 08084eb49..0e8811135 100644 --- a/BG3Extender/Osiris/OsirisExtender.cpp +++ b/BG3Extender/Osiris/OsirisExtender.cpp @@ -399,4 +399,38 @@ void OsirisExtender::BindCallbackManager(esv::lua::OsirisCallbackManager* mgr) } } +static uint32_t FunctionNameHash(char const * str) +{ + uint32_t hash{ 0 }; + while (*str) { + hash = (*str++ | 0x20) + 129 * (hash % 4294967); + } + + return hash; +} + +Function const* OsirisExtender::LookupFunction(const STDString& name, uint32_t arity) +{ + auto functions = GetGlobals().Functions; + if (!functions) { + return nullptr; + } + + OsiString sig(name); + sig += "/"; + sig += std::to_string(arity); + + auto hash = FunctionNameHash(name.c_str()) + arity; + auto func = (*functions)->Find(hash, sig); + if (func == nullptr + || ((*func)->Node.Id == 0 + && (*func)->Type != FunctionType::Call + && (*func)->Type != FunctionType::Query + && (*func)->Type != FunctionType::Event)) { + return nullptr; + } + + return *func; +} + } diff --git a/BG3Extender/Osiris/OsirisExtender.h b/BG3Extender/Osiris/OsirisExtender.h index 2894296f2..53d43ebaf 100644 --- a/BG3Extender/Osiris/OsirisExtender.h +++ b/BG3Extender/Osiris/OsirisExtender.h @@ -38,6 +38,13 @@ class OsirisExtender void LogWarning(std::string_view msg); void LogMessage(std::string_view msg); + Function const* LookupFunction(const STDString& name, uint32_t arity); + + inline ValueType GetBaseType(ValueType type) + { + return (*GetGlobals().Types)->ResolveAlias((uint16_t)type); + } + inline OsirisStaticGlobals const & GetGlobals() const { return wrappers_.Globals; From 09aa1c9ea609b8697e7ccb6ee0a815881c23205e Mon Sep 17 00:00:00 2001 From: Chris Feger Date: Wed, 13 Mar 2024 21:40:12 -0400 Subject: [PATCH 2/3] Add AsyncOsi to server --- BG3Extender/Lua/LuaBinding.cpp | 13 + BG3Extender/Lua/LuaBinding.h | 3 + BG3Extender/Lua/LuaOsiBridge.cpp | 412 +++++++++++++++++----- BG3Extender/Lua/Server/LuaBindingServer.h | 4 + BG3Extender/Lua/Server/LuaOsirisBinding.h | 78 +++- BG3Extender/Lua/Server/LuaServer.cpp | 3 + 6 files changed, 402 insertions(+), 111 deletions(-) diff --git a/BG3Extender/Lua/LuaBinding.cpp b/BG3Extender/Lua/LuaBinding.cpp index 1547a8df1..aea02abfa 100644 --- a/BG3Extender/Lua/LuaBinding.cpp +++ b/BG3Extender/Lua/LuaBinding.cpp @@ -501,6 +501,19 @@ namespace bg3se::lua } } + void State::ResolveOsirisFuture(uint32_t id, std::span results, bool success) + { + auto found = futureManager_.find(id); + if (found != futureManager_.end()) + { + found->second.Push(); + auto future = AsyncOsiFuture::CheckUserData(L, -1); + future->Resolve(results, success); + + futureManager_.erase(found); + } + } + STDString State::GetBuiltinLibrary(int resourceId) { auto resource = GetExeResource(resourceId); diff --git a/BG3Extender/Lua/LuaBinding.h b/BG3Extender/Lua/LuaBinding.h index 3a1c979e2..663f159f4 100644 --- a/BG3Extender/Lua/LuaBinding.h +++ b/BG3Extender/Lua/LuaBinding.h @@ -182,6 +182,7 @@ namespace bg3se::lua class AsyncOsiFuture* CreateOsirisFuture(uint32_t id); void ResolveOsirisFuture(uint32_t id, std::string_view err); void ResolveOsirisFuture(uint32_t id, Array>>& results); + void ResolveOsirisFuture(uint32_t id, std::span ref, bool success); template bool CallExtRet(char const * func, uint32_t restrictions, std::tuple& ret, Args... args) @@ -300,6 +301,8 @@ namespace bg3se::lua void Resolve(std::string_view err); void Resolve(Array>>& results); + + void Resolve(std::span result, bool succeeded); private: std::vector ThenList; std::optional Else; diff --git a/BG3Extender/Lua/LuaOsiBridge.cpp b/BG3Extender/Lua/LuaOsiBridge.cpp index 8aeab3b2d..ea5db38ea 100644 --- a/BG3Extender/Lua/LuaOsiBridge.cpp +++ b/BG3Extender/Lua/LuaOsiBridge.cpp @@ -56,7 +56,7 @@ namespace bg3se::lua void AsyncOsiFuture::Resolve(std::string_view err) { - ecl::LuaClientPin pin(ecl::ExtensionState::Get()); + LuaVirtualPin pin(gExtender->GetCurrentExtensionState()); if (pin) { auto L = pin->GetState(); @@ -71,7 +71,7 @@ namespace bg3se::lua void AsyncOsiFuture::Resolve(Array>>& results) { - ecl::LuaClientPin pin(ecl::ExtensionState::Get()); + LuaVirtualPin pin(gExtender->GetCurrentExtensionState()); if (pin) { auto L = pin->GetState(); @@ -86,6 +86,62 @@ namespace bg3se::lua } } + void AsyncOsiFuture::Resolve(std::span result, bool success) + { + LuaVirtualPin pin(gExtender->GetCurrentExtensionState()); + if (pin) + { + auto L = pin->GetState(); + if (success) + { + for (const auto& ref : result) + { + if (!ref.IsValid(L)) + { + // If not all references are valid, something weird happened. Just return + return; + } + } + for (const auto& then : ThenList) + { + then.Push(); + for (const auto& pref : result) + { + auto ref = pref.MakeRef(L); + ref.Push(L); + if (ref.Type() == RefType::Registry) + { + luaL_unref(L, LUA_REGISTRYINDEX, ref.Index()); + } + } + CheckedCall(L, result.size(), "AsyncOsiFutureSuccessCallback"); + } + lua_pop(L, 1); + } + else + { + if (Else) + { + Else->Push(); + if (result.size() == 1) + { + auto ref = result[0].MakeRef(L); + ref.Push(L); + if (ref.Type() == RefType::Registry) + { + luaL_unref(L, LUA_REGISTRYINDEX, ref.Index()); + } + } + else + { + lua_pushstring(L, "Unknown error"); + } + CheckedCall(L, 1, "AsyncOsiFutureErrorCallback"); + } + } + } + } + int AsyncOsiFuture::LuaThen(lua_State * L) { auto self = AsyncOsiFuture::CheckUserData(L, 1); @@ -342,35 +398,44 @@ namespace bg3se::esv::lua return true; } + std::pair BuildOsiError(lua_State * L, const char * fmt, ...) + { + va_list argp; + va_start(argp, fmt); + luaL_where(L, 1); + lua_pushvfstring(L, fmt, argp); + va_end(argp); + lua_concat(L, 2); + return {1, false}; + } + void OsiFunction::Unbind() { function_ = nullptr; } - int OsiFunction::LuaCall(lua_State * L) + std::pair OsiFunction::LuaCall(lua_State * L) { if (function_ == nullptr) { - return luaL_error(L, "Attempted to call an unbound Osiris function"); + return BuildOsiError(L, "Attempted to call an unbound Osiris function"); } int numArgs = lua_gettop(L); if (numArgs < 1) { - return luaL_error(L, "Called Osi function without 'self' argument?"); + return BuildOsiError(L, "Called Osi function without 'self' argument?"); } if (state_->RestrictionFlags & State::RestrictOsiris) { - return luaL_error(L, "Attempted to call Osiris function in restricted context"); + return BuildOsiError(L, "Attempted to call Osiris function in restricted context"); } switch (function_->Type) { case FunctionType::Call: - OsiCall(L); - return 0; + return OsiCall(L); case FunctionType::Event: case FunctionType::Proc: - OsiInsert(L, false); - return 0; + return OsiInsert(L, false); case FunctionType::Database: { @@ -380,13 +445,12 @@ namespace bg3se::esv::lua auto node = function_->Node.Get(); if (node) { if (node->IsDataNode()) { - OsiInsert(L, false); - return 0; + return OsiInsert(L, false); } else { return OsiUserQuery(L); } } else { - return luaL_error(L, "Function has no node!"); + return BuildOsiError(L, "Function has no node!"); } } @@ -399,27 +463,27 @@ namespace bg3se::esv::lua case FunctionType::SysCall: default: - return luaL_error(L, "Cannot call function of type %d", function_->Type); + return BuildOsiError(L, "Cannot call function of type %d", function_->Type); } } - int OsiFunction::LuaGet(lua_State * L) + std::pair OsiFunction::LuaGet(lua_State * L) { if (!IsBound()) { - return luaL_error(L, "Attempted to read an unbound Osiris database"); + return BuildOsiError(L, "Attempted to read an unbound Osiris database"); } if (!IsDB()) { - return luaL_error(L, "Attempted to read function that's not a database"); + return BuildOsiError(L, "Attempted to read function that's not a database"); } int numArgs = lua_gettop(L); if (numArgs < 1) { - return luaL_error(L, "Read Osi database without 'self' argument?"); + return BuildOsiError(L, "Read Osi database without 'self' argument?"); } if (state_->RestrictionFlags & State::RestrictOsiris) { - return luaL_error(L, "Attempted to read Osiris database in restricted context"); + return BuildOsiError(L, "Attempted to read Osiris database in restricted context"); } auto db = function_->Node.Get()->Database.Get(); @@ -440,52 +504,50 @@ namespace bg3se::esv::lua current = current->Next; } - return 1; + return {1, true}; } - int OsiFunction::LuaDelete(lua_State * L) + std::pair OsiFunction::LuaDelete(lua_State * L) { if (!IsBound()) { - return luaL_error(L, "Attempted to delete from an unbound Osiris database"); + return BuildOsiError(L, "Attempted to delete from an unbound Osiris database"); } if (!IsDB()) { - return luaL_error(L, "Attempted to delete from function that's not a database"); + return BuildOsiError(L, "Attempted to delete from function that's not a database"); } int numArgs = lua_gettop(L); if (numArgs < 1) { - return luaL_error(L, "Delete from Osi database without 'self' argument?"); + return BuildOsiError(L, "Delete from Osi database without 'self' argument?"); } if (state_->RestrictionFlags & State::RestrictOsiris) { - return luaL_error(L, "Attempted to delete from Osiris database in restricted context"); + return BuildOsiError(L, "Attempted to delete from Osiris database in restricted context"); } - OsiInsert(L, true); - return 0; + return OsiInsert(L, true); } - int OsiFunction::LuaDeferredNotification(lua_State * L) + std::pair OsiFunction::LuaDeferredNotification(lua_State * L) { if (function_ == nullptr) { - return luaL_error(L, "Attempted to call an unbound Osiris function"); + return BuildOsiError(L, "Attempted to call an unbound Osiris function"); } int numArgs = lua_gettop(L); if (numArgs < 1) { - return luaL_error(L, "Called Osi function without 'self' argument?"); + return BuildOsiError(L, "Called Osi function without 'self' argument?"); } if (state_->RestrictionFlags & State::RestrictOsiris) { - return luaL_error(L, "Attempted to call Osiris function in restricted context"); + return BuildOsiError(L, "Attempted to call Osiris function in restricted context"); } if (function_->Type == FunctionType::Event) { - OsiDeferredNotification(L); - return 0; + return OsiDeferredNotification(L); } else { - return luaL_error(L, "Cannot queue deferred events on function of type %d", function_->Type); + return BuildOsiError(L, "Cannot queue deferred events on function of type %d", function_->Type); } } @@ -555,12 +617,12 @@ namespace bg3se::esv::lua } } - void OsiFunction::OsiCall(lua_State * L) + std::pair OsiFunction::OsiCall(lua_State * L) { auto funcArgs = function_->Signature->Params->Params.Size; int numArgs = lua_gettop(L); if (numArgs - 1 != funcArgs) { - luaL_error(L, "Incorrect number of arguments for '%s'; expected %d, got %d", + return BuildOsiError(L, "Incorrect number of arguments for '%s'; expected %d, got %d", function_->Signature->Name, funcArgs, numArgs - 1); } @@ -576,9 +638,11 @@ namespace bg3se::esv::lua } gExtender->GetServer().Osiris().GetWrappers().Call.CallWithHooks(function_->GetHandle(), funcArgs == 0 ? nullptr : args.Args()); + + return {0, true}; } - void OsiFunction::OsiDeferredNotification(lua_State * L) + std::pair OsiFunction::OsiDeferredNotification(lua_State * L) { /*auto funcArgs = function_->Signature->Params->Params.Size; int numArgs = lua_gettop(L); @@ -614,19 +678,20 @@ namespace bg3se::esv::lua story->Manager->QueueNotification(notification);*/ OsiError("FIXME: OsiDeferredNotification not implemented yet!"); + return {0, false}; } - void OsiFunction::OsiInsert(lua_State * L, bool deleteTuple) + std::pair OsiFunction::OsiInsert(lua_State * L, bool deleteTuple) { auto funcArgs = function_->Signature->Params->Params.Size; int numArgs = lua_gettop(L); if (numArgs - 1 != funcArgs) { - luaL_error(L, "Incorrect number of arguments for '%s'; expected %d, got %d", + return BuildOsiError(L, "Incorrect number of arguments for '%s'; expected %d, got %d", function_->Signature->Name, funcArgs, numArgs - 1); } if (function_->Node.Id == 0) { - luaL_error(L, "Function has no node"); + return BuildOsiError(L, "Function has no node"); } OsiArgumentListPin tvs(state_->Osiris().GetTypedValuePool(), (uint32_t)funcArgs); @@ -653,9 +718,11 @@ namespace bg3se::esv::lua } else { node->InsertTuple(&tuple); } + + return {0, true}; } - int OsiFunction::OsiQuery(lua_State * L) + std::pair OsiFunction::OsiQuery(lua_State * L) { auto outParams = function_->Signature->OutParamList.numOutParams(); auto numParams = function_->Signature->Params->Params.Size; @@ -663,7 +730,7 @@ namespace bg3se::esv::lua int numArgs = lua_gettop(L); if (numArgs - 1 != inParams) { - return luaL_error(L, "Incorrect number of IN arguments for '%s'; expected %d, got %d", + return BuildOsiError(L, "Incorrect number of IN arguments for '%s'; expected %d, got %d", function_->Signature->Name, inParams, numArgs - 1); } @@ -688,7 +755,7 @@ namespace bg3se::esv::lua bool handled = gExtender->GetServer().Osiris().GetWrappers().Query.CallWithHooks(function_->GetHandle(), numParams == 0 ? nullptr : args.Args()); if (outParams == 0) { push(L, handled); - return 1; + return {1, true}; } else { if (handled) { for (uint32_t i = 0; i < numParams; i++) { @@ -702,11 +769,11 @@ namespace bg3se::esv::lua } } - return outParams; + return {outParams, true}; } } - int OsiFunction::OsiUserQuery(lua_State * L) + std::pair OsiFunction::OsiUserQuery(lua_State * L) { auto outParams = function_->Signature->OutParamList.numOutParams(); auto numParams = function_->Signature->Params->Params.Size; @@ -714,7 +781,7 @@ namespace bg3se::esv::lua int numArgs = lua_gettop(L); if (numArgs - 1 != inParams) { - return luaL_error(L, "Incorrect number of IN arguments for '%s'; expected %d, got %d", + return BuildOsiError(L, "Incorrect number of IN arguments for '%s'; expected %d, got %d", function_->Signature->Name, inParams, numArgs - 1); } @@ -759,10 +826,10 @@ namespace bg3se::esv::lua retType = retType->Next; } - return outParams; + return {outParams, true}; } else { push(L, true); - return 1; + return {1, true}; } } else { if (outParams > 0) { @@ -770,47 +837,28 @@ namespace bg3se::esv::lua lua_pushnil(L); } - return outParams; + return {outParams, true}; } else { push(L, false); - return 1; + return {1, true}; } } } - - char const * const OsiFunctionNameProxy::MetatableName = "OsiFunctionNameProxy"; - - void OsiFunctionNameProxy::PopulateMetatable(lua_State * L) - { - lua_newtable(L); - - lua_pushcfunction(L, &LuaGet); - lua_setfield(L, -2, "Get"); - - lua_pushcfunction(L, &LuaDelete); - lua_setfield(L, -2, "Delete"); - - lua_pushcfunction(L, &LuaDeferredNotification); - lua_setfield(L, -2, "Defer"); - - lua_setfield(L, -2, "__index"); - } - - OsiFunctionNameProxy::OsiFunctionNameProxy(STDString const & name, ServerState & state) + OsiFunctionNameProxyBase::OsiFunctionNameProxyBase(STDString const & name, ServerState & state) : name_(name), state_(state), generationId_(state_.Osiris().GenerationId()) {} - void OsiFunctionNameProxy::UnbindAll() + void OsiFunctionNameProxyBase::UnbindAll() { functions_.clear(); } - bool OsiFunctionNameProxy::BeforeCall(lua_State * L) + bool OsiFunctionNameProxyBase::BeforeCall(lua_State * L) { if (state_.RestrictionFlags & State::RestrictOsiris) { - luaL_error(L, "Attempted to access Osiris function in restricted context"); + BuildOsiError(L, "Attempted to access Osiris function in restricted context"); return false; } @@ -823,76 +871,73 @@ namespace bg3se::esv::lua return true; } - int OsiFunctionNameProxy::LuaCall(lua_State * L) + std::pair OsiFunctionNameProxyBase::DoLuaCall(lua_State * L) { - if (!BeforeCall(L)) return 1; + if (!BeforeCall(L)) return {1, false}; auto arity = (uint32_t)lua_gettop(L) - 1; auto func = TryGetFunction(arity); if (func == nullptr) { - return luaL_error(L, "No function named '%s' exists that can be called with %d parameters.", + return BuildOsiError(L, "No function named '%s' exists that can be called with %d parameters.", name_.c_str(), arity); } return func->LuaCall(L); } - int OsiFunctionNameProxy::LuaGet(lua_State * L) + std::pair OsiFunctionNameProxyBase::DoLuaGet(OsiFunctionNameProxyBase * self, lua_State * L) { - auto self = OsiFunctionNameProxy::CheckUserData(L, 1); - if (!self->BeforeCall(L)) return 1; + if (!self->BeforeCall(L)) return {1, false}; auto arity = (uint32_t)lua_gettop(L) - 1; auto func = self->TryGetFunction(arity); if (func == nullptr) { - return luaL_error(L, "No database named '%s(%d)' exists", self->name_.c_str(), arity); + return BuildOsiError(L, "No database named '%s(%d)' exists", self->name_.c_str(), arity); } if (!func->IsDB()) { - return luaL_error(L, "Function '%s(%d)' is not a database", self->name_.c_str(), arity); + return BuildOsiError(L, "Function '%s(%d)' is not a database", self->name_.c_str(), arity); } return func->LuaGet(L); } - int OsiFunctionNameProxy::LuaDelete(lua_State * L) + std::pair OsiFunctionNameProxyBase::DoLuaDelete(OsiFunctionNameProxyBase * self, lua_State * L) { - auto self = OsiFunctionNameProxy::CheckUserData(L, 1); - if (!self->BeforeCall(L)) return 1; + if (!self->BeforeCall(L)) return {1, false}; auto arity = (uint32_t)lua_gettop(L) - 1; auto func = self->TryGetFunction(arity); if (func == nullptr) { - return luaL_error(L, "No database named '%s(%d)' exists", self->name_.c_str(), arity); + return BuildOsiError(L, "No database named '%s(%d)' exists", self->name_.c_str(), arity); } if (!func->IsDB()) { - return luaL_error(L, "Function '%s(%d)' is not a database", self->name_.c_str(), arity); + return BuildOsiError(L, "Function '%s(%d)' is not a database", self->name_.c_str(), arity); } return func->LuaDelete(L); } - int OsiFunctionNameProxy::LuaDeferredNotification(lua_State * L) + std::pair OsiFunctionNameProxyBase::DoLuaDeferredNotification(OsiFunctionNameProxyBase * self, lua_State * L) { - auto self = OsiFunctionNameProxy::CheckUserData(L, 1); - if (!self->BeforeCall(L)) return 1; + if (!self->BeforeCall(L)) return {1, false}; auto arity = (uint32_t)lua_gettop(L) - 1; auto func = self->TryGetFunction(arity); if (func == nullptr) { - return luaL_error(L, "No function named '%s' exists that can be called with %d parameters.", + return BuildOsiError(L, "No function named '%s' exists that can be called with %d parameters.", self->name_.c_str(), arity); } return func->LuaDeferredNotification(L); } - OsiFunction * OsiFunctionNameProxy::TryGetFunction(uint32_t arity) + OsiFunction * OsiFunctionNameProxyBase::TryGetFunction(uint32_t arity) { if (functions_.size() > arity && functions_[arity].IsBound()) { @@ -920,7 +965,7 @@ namespace bg3se::esv::lua return nullptr; } - OsiFunction * OsiFunctionNameProxy::CreateFunctionMapping(uint32_t arity, Function const * func) + OsiFunction * OsiFunctionNameProxyBase::CreateFunctionMapping(uint32_t arity, Function const * func) { if (functions_.size() <= arity) { functions_.resize(arity + 1); @@ -933,6 +978,154 @@ namespace bg3se::esv::lua } } + char const * const OsiFunctionNameProxy::MetatableName = "OsiFunctionNameProxy"; + + void OsiFunctionNameProxy::PopulateMetatable(lua_State * L) + { + lua_newtable(L); + + lua_pushcfunction(L, &LuaGet); + lua_setfield(L, -2, "Get"); + + lua_pushcfunction(L, &LuaDelete); + lua_setfield(L, -2, "Delete"); + + lua_pushcfunction(L, &LuaDeferredNotification); + lua_setfield(L, -2, "Defer"); + + lua_setfield(L, -2, "__index"); + } + + int OsiFunctionNameProxy::LuaCall(lua_State * L) + { + auto [num, success] = DoLuaCall(L); + if (!success) + { + lua_error(L); + return 1; + } + + return num; + } + + int OsiFunctionNameProxy::LuaGet(lua_State * L) + { + auto [num, success] = DoLuaGet(OsiFunctionNameProxy::CheckUserData(L, 1), L); + if (!success) + { + lua_error(L); + return 1; + } + + return num; + } + + int OsiFunctionNameProxy::LuaDelete(lua_State * L) + { + auto [num, success] = DoLuaDelete(OsiFunctionNameProxy::CheckUserData(L, 1), L); + if (!success) + { + lua_error(L); + return 1; + } + + return num; + } + + int OsiFunctionNameProxy::LuaDeferredNotification(lua_State * L) + { + auto [num, success] = DoLuaDeferredNotification(OsiFunctionNameProxy::CheckUserData(L, 1), L); + if (!success) + { + lua_error(L); + return num; + } + + return num; + } + + char const * const AsyncOsiFunctionNameProxy::MetatableName = "AsyncOsiFunctionNameProxy"; + std::atomic AsyncOsiFunctionNameProxy::currQueryId = 0; + + void AsyncOsiFunctionNameProxy::PopulateMetatable(lua_State * L) + { + lua_newtable(L); + + lua_pushcfunction(L, &LuaGet); + lua_setfield(L, -2, "Get"); + + lua_pushcfunction(L, &LuaDelete); + lua_setfield(L, -2, "Delete"); + + lua_pushcfunction(L, &LuaDeferredNotification); + lua_setfield(L, -2, "Defer"); + + lua_setfield(L, -2, "__index"); + } + + int AsyncOsiFunctionNameProxy::SharedLuaLogic(const std::pair& result, lua_State* L) + { + const auto& [num, success] = result; + uint32_t queryid = currQueryId++; + + if (!success) + { + PersistentRegistryEntry errEntry(L, -1); + lua::PersistentRef err(L, errEntry); + errEntry.ResetWithoutUnbind(); + lua_pop(L, 1); + auto future = AsyncOsiFuture::New(L, queryid); + gExtender->GetServer().EnqueueTask([err=std::move(err), queryid]{ + LuaServerPin pin(esv::ExtensionState::Get()); + if (pin) + { + pin->ResolveOsirisFuture(queryid, std::span(&err, 1), false); + } + }); + } + else + { + std::vector entries; + for (int i = 0; i < num; i++) + { + PersistentRegistryEntry entry(L, -1 - i); + entries.emplace_back(L, entry); + entry.ResetWithoutUnbind(); + } + lua_pop(L, num); + auto future = AsyncOsiFuture::New(L, queryid); + gExtender->GetServer().EnqueueTask([entries=std::move(entries), queryid]{ + LuaServerPin pin(esv::ExtensionState::Get()); + if (pin) + { + pin->ResolveOsirisFuture(queryid, std::span(entries), true); + } + }); + } + return 1; + } + + int AsyncOsiFunctionNameProxy::LuaCall(lua_State * L) + { + return SharedLuaLogic(DoLuaCall(L), L); + } + + int AsyncOsiFunctionNameProxy::LuaGet(lua_State * L) + { + return SharedLuaLogic(DoLuaGet(AsyncOsiFunctionNameProxy::CheckUserData(L, 1), L), L); + } + + int AsyncOsiFunctionNameProxy::LuaDelete(lua_State * L) + { + return SharedLuaLogic(DoLuaDelete(AsyncOsiFunctionNameProxy::CheckUserData(L, 1), L), L); + } + + int AsyncOsiFunctionNameProxy::LuaDeferredNotification(lua_State * L) + { + return SharedLuaLogic(DoLuaDeferredNotification(AsyncOsiFunctionNameProxy::CheckUserData(L, 1), L), L); + } + + bool CustomLuaCall::Call(OsiArgumentDesc const & params) { @@ -1157,6 +1350,41 @@ namespace bg3se::esv::lua return 1; } + + char const * const ExtensionLibraryServer::AsyncNameResolverMetatableName = "AsyncOsiProxyNameResolver"; + + void ExtensionLibraryServer::RegisterAsyncNameResolverMetatable(lua_State * L) + { + lua_register(L, AsyncNameResolverMetatableName, nullptr); + luaL_newmetatable(L, AsyncNameResolverMetatableName); // stack: mt + lua_pushcfunction(L, &AsyncLuaIndexResolverTable); // stack: mt, &LuaIndexResolverTable + lua_setfield(L, -2, "__index"); // mt.__index = &LuaIndexResolverTable; stack: mt + lua_pop(L, 1); // stack: mt + } + + void ExtensionLibraryServer::CreateAsyncNameResolver(lua_State * L) + { + lua_newtable(L); // stack: osi + luaL_setmetatable(L, AsyncNameResolverMetatableName); // stack: osi + lua_setglobal(L, "AsyncOsi"); // stack: - + } + + int ExtensionLibraryServer::AsyncLuaIndexResolverTable(lua_State * L) + { + luaL_checktype(L, 1, LUA_TTABLE); + auto name = luaL_checkstring(L, 2); + + LuaServerPin lua(ExtensionState::Get()); + AsyncOsiFunctionNameProxy::New(L, name, std::ref(lua.Get())); // stack: tab, name, proxy + + lua_pushvalue(L, 1); // stack: fun, tab + push(L, name); // stack: fun, tab, name + lua_pushvalue(L, -3); // stack: fun, tab, name, fun + lua_rawset(L, -3); // stack: fun + lua_pop(L, 1); + return 1; + } + STDString ExtensionLibraryServer::GenerateOsiHelpers() { std::stringstream ss; diff --git a/BG3Extender/Lua/Server/LuaBindingServer.h b/BG3Extender/Lua/Server/LuaBindingServer.h index d8113b0b5..511e802ec 100644 --- a/BG3Extender/Lua/Server/LuaBindingServer.h +++ b/BG3Extender/Lua/Server/LuaBindingServer.h @@ -240,11 +240,15 @@ namespace bg3se::esv::lua private: static char const * const NameResolverMetatableName; + static char const * const AsyncNameResolverMetatableName; void RegisterNameResolverMetatable(lua_State * L); void CreateNameResolver(lua_State * L); + void RegisterAsyncNameResolverMetatable(lua_State * L); + void CreateAsyncNameResolver(lua_State * L); static int LuaIndexResolverTable(lua_State* L); + static int AsyncLuaIndexResolverTable(lua_State* L); }; class ServerState : public State diff --git a/BG3Extender/Lua/Server/LuaOsirisBinding.h b/BG3Extender/Lua/Server/LuaOsirisBinding.h index ced15bb3d..569dd4844 100644 --- a/BG3Extender/Lua/Server/LuaOsirisBinding.h +++ b/BG3Extender/Lua/Server/LuaOsirisBinding.h @@ -58,6 +58,8 @@ void LuaToOsi(lua_State * L, int i, OsiArgumentValue & arg, ValueType osiType, b void OsiToLua(lua_State * L, OsiArgumentValue const & arg); void OsiToLua(lua_State * L, TypedValue const & tv); +std::pair BuildOsiError(lua_State * L, const char *fmt, ...); + class OsiFunction { public: @@ -74,27 +76,70 @@ class OsiFunction bool Bind(Function const * func, class ServerState & state); void Unbind(); - int LuaCall(lua_State * L); - int LuaGet(lua_State * L); - int LuaDelete(lua_State * L); - int LuaDeferredNotification(lua_State * L); + std::pair LuaCall(lua_State * L); + std::pair LuaGet(lua_State * L); + std::pair LuaDelete(lua_State * L); + std::pair LuaDeferredNotification(lua_State * L); private: Function const * function_{ nullptr }; AdapterRef adapter_; ServerState * state_; - void OsiCall(lua_State * L); - void OsiDeferredNotification(lua_State * L); - void OsiInsert(lua_State * L, bool deleteTuple); - int OsiQuery(lua_State * L); - int OsiUserQuery(lua_State * L); + std::pair OsiCall(lua_State * L); + std::pair OsiDeferredNotification(lua_State * L); + std::pair OsiInsert(lua_State * L, bool deleteTuple); + std::pair OsiQuery(lua_State * L); + std::pair OsiUserQuery(lua_State * L); bool MatchTuple(lua_State * L, int firstIndex, TupleVec const & tuple); void ConstructTuple(lua_State * L, TupleVec const & tuple); }; -class OsiFunctionNameProxy : public Userdata, public Callable +class OsiFunctionNameProxyBase; + +class OsiFunctionNameProxyBase +{ +public: + static constexpr uint32_t MaxQueryOutParams = 6; + + OsiFunctionNameProxyBase(STDString const & name, ServerState & state); + + void UnbindAll(); + +protected: + STDString name_; + Vector functions_; + ServerState & state_; + uint32_t generationId_; + + std::pair DoLuaCall(lua_State * L); + static std::pair DoLuaGet(OsiFunctionNameProxyBase* self, lua_State * L); + static std::pair DoLuaDelete(OsiFunctionNameProxyBase* self, lua_State * L); + static std::pair DoLuaDeferredNotification(OsiFunctionNameProxyBase* self, lua_State * L); + bool BeforeCall(lua_State * L); + OsiFunction * TryGetFunction(uint32_t arity); + OsiFunction * CreateFunctionMapping(uint32_t arity, Function const * func); +}; + +class OsiFunctionNameProxy : public OsiFunctionNameProxyBase, public Userdata, public Callable +{ +public: + static char const * const MetatableName; + // Maximum number of OUT params that a query can return. + // (This setting determines how many function arities we'll check during name lookup) + + static void PopulateMetatable(lua_State * L); + + using OsiFunctionNameProxyBase::OsiFunctionNameProxyBase; + + int LuaCall(lua_State * L); + static int LuaGet(lua_State * L); + static int LuaDelete(lua_State * L); + static int LuaDeferredNotification(lua_State * L); +}; + +class AsyncOsiFunctionNameProxy : public OsiFunctionNameProxyBase, public Userdata, public Callable { public: static char const * const MetatableName; @@ -104,23 +149,18 @@ class OsiFunctionNameProxy : public Userdata, public Calla static void PopulateMetatable(lua_State * L); - OsiFunctionNameProxy(STDString const & name, ServerState & state); + using OsiFunctionNameProxyBase::OsiFunctionNameProxyBase; - void UnbindAll(); int LuaCall(lua_State * L); private: - STDString name_; - Vector functions_; - ServerState & state_; - uint32_t generationId_; + static std::atomic currQueryId; static int LuaGet(lua_State * L); static int LuaDelete(lua_State * L); static int LuaDeferredNotification(lua_State * L); - bool BeforeCall(lua_State * L); - OsiFunction * TryGetFunction(uint32_t arity); - OsiFunction * CreateFunctionMapping(uint32_t arity, Function const * func); + + static int SharedLuaLogic(const std::pair& result, lua_State * L); }; diff --git a/BG3Extender/Lua/Server/LuaServer.cpp b/BG3Extender/Lua/Server/LuaServer.cpp index 19a9691f4..7b6cb6157 100644 --- a/BG3Extender/Lua/Server/LuaServer.cpp +++ b/BG3Extender/Lua/Server/LuaServer.cpp @@ -125,6 +125,9 @@ namespace bg3se::esv::lua OsiFunctionNameProxy::RegisterMetatable(L); RegisterNameResolverMetatable(L); CreateNameResolver(L); + AsyncOsiFunctionNameProxy::RegisterMetatable(L); + RegisterAsyncNameResolverMetatable(L); + CreateAsyncNameResolver(L); } From 3d29ba8adee52f8c120e20a9228022978c1e9a06 Mon Sep 17 00:00:00 2001 From: Chris Feger Date: Tue, 19 Mar 2024 14:01:56 -0400 Subject: [PATCH 3/3] Set accidentally-forgotten indices on query args --- BG3Extender/Lua/LuaOsiBridge.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/BG3Extender/Lua/LuaOsiBridge.cpp b/BG3Extender/Lua/LuaOsiBridge.cpp index ea5db38ea..5359bb522 100644 --- a/BG3Extender/Lua/LuaOsiBridge.cpp +++ b/BG3Extender/Lua/LuaOsiBridge.cpp @@ -1533,22 +1533,32 @@ namespace bg3se::ecl::lua case LUA_TLIGHTUSERDATA: { auto handle = get(L, idx); - query->add_args()->set_intv((int64_t)handle.Handle); + auto arg = query->add_args(); + arg->set_index(i); + arg->set_intv((int64_t)handle.Handle); break; } case LUA_TNUMBER: + { + auto arg = query->add_args(); + arg->set_index(i); if (lua_isinteger(L, idx)) { - query->add_args()->set_intv(lua_tointeger(L, idx)); + arg->set_intv(lua_tointeger(L, idx)); } else { - query->add_args()->set_numv((float)lua_tonumber(L, idx)); + arg->set_numv((float)lua_tonumber(L, idx)); } break; + } case LUA_TSTRING: - query->add_args()->set_strv(lua_tostring(L, idx)); + { + auto arg = query->add_args(); + arg->set_index(i); + arg->set_strv(lua_tostring(L, idx)); break; + } } }