diff --git a/lua_protobuf/generator.py b/lua_protobuf/generator.py index 7081667..8ec2aa6 100644 --- a/lua_protobuf/generator.py +++ b/lua_protobuf/generator.py @@ -79,6 +79,8 @@ def lua_protobuf_header(): // GC callback function that always returns true LUA_PROTOBUF_EXPORT int lua_protobuf_gc_always_free(::google::protobuf::Message *msg, void *userdata); +LUA_PROTOBUF_EXPORT int lua_protobuf_find_or_create_nested_table (lua_State *L, const char *fname, int szhint); + #ifdef __cplusplus } #endif @@ -118,6 +120,30 @@ def lua_protobuf_source(): return 1; } +int lua_protobuf_find_or_create_nested_table (lua_State *L, const char *fname, int szhint) { + const char *e; + lua_pushglobaltable(L); + do { + e = strchr(fname, '.'); + if (e == NULL) e = fname + strlen(fname); + lua_pushlstring(L, fname, e - fname); + if (lua_rawget(L, -2) == LUA_TNIL) { /* no such field? */ + lua_pop(L, 1); /* remove this nil */ + lua_createtable(L, 0, (*e == '.' ? 1 : szhint)); /* new table for field */ + lua_pushlstring(L, fname, e - fname); + lua_pushvalue(L, -2); + lua_settable(L, -4); /* set new table into field */ + } + else if (!lua_istable(L, -1)) { /* field has a non-table value? */ + lua_pop(L, 2); /* remove table and value */ + return 0; /* return problematic part of the name */ + } + lua_remove(L, -2); /* remove previous table */ + fname = e + 1; + } while (*e == '.'); + return 1; +} + ''' def c_header_header(filename, package): @@ -334,7 +360,7 @@ def field_get(package, message, field_descriptor): lines.append('lua_pushnumber(L, m->%s(index-1));' % name) elif type == FieldDescriptor.TYPE_ENUM: - lines.append('lua_pushnumber(L, m->%s(index-1));' % name) + lines.append('lua_pushinteger(L, m->%s(index-1));' % name) elif type == FieldDescriptor.TYPE_MESSAGE: lines.extend([ @@ -349,24 +375,24 @@ def field_get(package, message, field_descriptor): # this is the Lua way if type == FieldDescriptor.TYPE_STRING or type == FieldDescriptor.TYPE_BYTES: lines.append('string s = m->%s();' % name) - lines.append('m->has_%s() ? lua_pushlstring(L, s.c_str(), s.size()) : lua_pushnil(L);' % name) + lines.append('if (m->has_%s()) { lua_pushlstring(L, s.c_str(), s.size()); } else { lua_pushnil(L); }' % name) elif type == FieldDescriptor.TYPE_BOOL: - lines.append('m->has_%s() ? lua_pushboolean(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('if (m->has_%s()) { lua_pushboolean(L, m->%s()); } else { lua_pushnil(L); }' % ( name, name )) elif type in [FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_UINT32, FieldDescriptor.TYPE_FIXED32, FieldDescriptor.TYPE_SFIXED32, FieldDescriptor.TYPE_SINT32]: - lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('if (m->has_%s()) { lua_pushinteger(L, m->%s()); } else { lua_pushnil(L); }' % ( name, name )) elif type in [ FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_FIXED64, FieldDescriptor.TYPE_SFIXED64, FieldDescriptor.TYPE_SINT64]: - lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('if (m->has_%s()) { lua_pushinteger(L, m->%s()); } else { lua_pushnil(L); }' % ( name, name )) elif type == FieldDescriptor.TYPE_FLOAT or type == FieldDescriptor.TYPE_DOUBLE: - lines.append('m->has_%s() ? lua_pushnumber(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('if (m->has_%s()) { lua_pushnumber(L, m->%s()); } else { lua_pushnil(L); }' % ( name, name )) elif type == FieldDescriptor.TYPE_ENUM: - lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('if (m->has_%s()) { lua_pushinteger(L, m->%s()); } else { lua_pushnil(L); }' % ( name, name )) elif type == FieldDescriptor.TYPE_MESSAGE: lines.extend([ @@ -765,11 +791,15 @@ def message_open_function(package, descriptor): lines = [ 'int %s(lua_State *L)' % message_open_function_name(package, message), '{', - 'luaL_newmetatable(L, "%s");' % metatable(package, message), - 'lua_pushvalue(L, -1);', - 'lua_setfield(L, -2, "__index");', - 'luaL_register(L, NULL, %s_methods);' % message, - 'luaL_register(L, "%s", %s_functions);' % (lua_libname(package, message), message), + 'luaL_newmetatable(L, "%s");' % metatable(package, message), # stack: mt + 'lua_pushvalue(L, -1);', # stack: mt mt + 'lua_setfield(L, -2, "__index");', # stack: mt + 'luaL_setfuncs(L, %s_methods, 0);' % message, # stack: mt + 'if (!lua_protobuf_find_or_create_nested_table(L, "%s", 1)) {' % lua_libname(package, message), # stack: mt nested_table + 'return luaL_error(L, "could not create nested lua table %s");' % lua_libname(package, message), + '}', + 'luaL_setfuncs(L, %s_functions, 0);' % message, # stack: mt nested_table + 'lua_pop(L, 1);', # stack: mt ] for enum_descriptor in descriptor.enum_type: @@ -777,8 +807,8 @@ def message_open_function(package, descriptor): lines.extend([ # this is wrong if we are calling through normal Lua module load means - 'lua_pop(L, 1);', - 'return 1;', + 'lua_pop(L, 1);', # stack: (empty) + 'return 1;', # return 1 means "true" instead of "one value to return" '}', '\n', ]) @@ -958,7 +988,7 @@ def enum_source(descriptor): k = value.name v = value.number lines.extend([ - 'lua_pushnumber(L, %d);' % v, + 'lua_pushinteger(L, %d);' % v, 'lua_setfield(L, -2, "%s");' % k ]) @@ -1031,17 +1061,9 @@ def file_source(file_descriptor): # i.e. protobuf.package.foo.enum => protobuf['package']['foo']['enum'] # we interate over all the tables and create missing ones, as necessary - # we cheat here and use the undocumented/internal luaL_findtable function - # we probably shouldn't rely on an "internal" API, so - # TODO don't use internal API call lines.extend([ - 'const char *table = luaL_findtable(L, LUA_GLOBALSINDEX, "protobuf.%s", 1);' % package, - 'if (table) {', - 'return luaL_error(L, "could not create parent Lua tables");', - '}', - 'if (!lua_istable(L, -1)) {', - 'lua_newtable(L);', - 'lua_setfield(L, -2, "%s");' % package, + 'if (!lua_protobuf_find_or_create_nested_table(L, "protobuf.%s", 1)) {' % package, # stack: nested_table + 'return luaL_error(L, "could not create nested lua table protobuf.%s");' % package, '}', ]) @@ -1049,18 +1071,15 @@ def file_source(file_descriptor): lines.extend(enum_source(descriptor)) lines.extend([ - # don't need main table on stack any more - 'lua_pop(L, 1);', - - # and we register this package as a module, complete with enumerations 'luaL_Reg funcs [] = { { NULL, NULL } };', - 'luaL_register(L, "protobuf.%s", funcs);' % package, + 'luaL_setfuncs(L, funcs, 0);', + 'lua_pop(L, 1);', # stack: (empty) ]) for descriptor in file_descriptor.message_type: lines.append('%s(L);' % message_open_function_name(package, descriptor.name)) - lines.append('return 1;') + lines.append('return 1;') # return 1 means "true" instead of "one value to return" lines.append('}') lines.append('\n')