diff --git a/src/msgpack.lua b/src/msgpack.lua index 34ddd1b..f4a5111 100644 --- a/src/msgpack.lua +++ b/src/msgpack.lua @@ -347,163 +347,18 @@ local function parse(message: buffer, offset: number): (any, number) error("Not all decoder cases are handled, report as bug to msgpack-luau maintainer") end -local function computeLength(data: any, tableSet: {[any]: boolean}): number - local dtype = type(data) - if data == nil then - return 1 - elseif dtype == "boolean" then - return 1 - elseif dtype == "string" then - local length = #data - - if length <= 31 then - return 1 + length - elseif length <= 0xFF then - return 2 + length - elseif length <= 0xFFFF then - return 3 + length - elseif length <= 0xFFFFFFFF then - return 5 + length - end - - error("Could not encode - too long string") - - elseif dtype == "buffer" then - local length = bufferLen(data) - - if length <= 0xFF then - return 2 + length - elseif length <= 0xFFFF then - return 3 + length - elseif length <= 0xFFFFFFFF then - return 5 + length - end - - error("Could not encode - too long binary buffer") - - elseif dtype == "number" then - -- represents NaN, Inf, -Inf as float 32 to save space - if data == 0 then - return 1 - elseif data ~= data then -- NaN - return 5 - elseif data == math.huge then - return 5 - elseif data == -math.huge then - return 5 - end - - local integral, fractional = modf(data) - local sign = sign(data) - - if fractional ~= 0 or integral > 0xFFFFFFFF or integral < -0x80000000 then - -- float 64 - return 9 - end - - if sign > 0 then - if integral <= 127 then -- positive fixint - return 1 - elseif integral <= 0xFF then -- uint 8 - return 2 - elseif integral <= 0xFFFF then -- uint 16 - return 3 - elseif integral <= 0xFFFFFFFF then -- uint 32 - return 5 - end - else - if integral >= -0x20 then -- negative fixint - return 1 - elseif integral >= -0x80 then -- int 8 - return 2 - elseif integral >= -0x8000 then -- int 16 - return 3 - elseif integral >= -0x80000000 then -- int 32 - return 5 - end - end - - error(string.format("Could not encode - unhandled number \"%s\"", typeof(data))) - - elseif dtype == "table" then - local msgpackType = data._msgpackType - - if msgpackType then - if msgpackType == msgpack.Int64 or msgpackType == msgpack.UInt64 then - return 9 - elseif msgpackType == msgpack.Extension then - local length = bufferLen(data.data) - - if length == 1 then - return 3 - elseif length == 2 then - return 4 - elseif length == 4 then - return 6 - elseif length == 8 then - return 10 - elseif length == 16 then - return 18 - elseif length <= 0xFF then - return 3 + length - elseif length <= 0xFFFF then - return 4 + length - elseif length <= 0xFFFFFFFF then - return 6 + length - end - - error("Could not encode - too long extension data") - end - end - - if tableSet[data] then - error("Can not serialize cyclic table") - else - tableSet[data] = true - end - - local length = #data - local mapLength = 0 - - for _,_ in pairs(data) do - mapLength += 1 - end - - local headerLen - if mapLength <= 15 then - headerLen = 1 - elseif mapLength <= 0xFFFF then - headerLen = 3 - elseif mapLength <= 0xFFFFFFFF then - headerLen = 5 - else - if length == mapLength then - error("Could not encode - too long array") - else - error("Could not encode - too long map") - end - end - - if length == mapLength then -- array - local contentLen = 0 - for _,v in ipairs(data) do - contentLen += computeLength(v, tableSet) - end - - return headerLen + contentLen - - else -- map - local contentLen = 0 - for k,v in pairs(data) do - contentLen += computeLength(k, tableSet) - contentLen += computeLength(v, tableSet) - end +local function inflate(result: buffer, minSize: number, oldSize: number) + if oldSize == 0 then + return bufferCreate(minSize), minSize + end - return headerLen + contentLen - end + while minSize > oldSize do + oldSize *= 2 end - error(string.format("Could not encode - unsupported datatype \"%s\"", typeof(data))) + local temp = bufferCreate(oldSize) + bufferCopy(temp, 0, result) + return temp, oldSize end local extensionTypeLUT = { @@ -514,40 +369,74 @@ local extensionTypeLUT = { [16] = 0xD8, } -local function encode(result: buffer, offset: number, data: any): number +local function encode( + result: buffer, + offset: number, + size: number, + data: any, + tableSet: { [any]: boolean } +): (buffer, number, number) local dtype = type(data) if data == nil then + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writestring(result, offset, "\xC0") - return offset + 1 + return result, offset + 1, size elseif data == false then + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writestring(result, offset, "\xC2") - return offset + 1 + return result, offset + 1, size elseif data == true then + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writestring(result, offset, "\xC3") - return offset + 1 + return result, offset + 1, size elseif dtype == "string" then local length = #data if length <= 31 then + if offset + 1 + length > size then + result, size = inflate(result, offset + 1 + length, size) + end + writeu8(result, offset, bor(0xA0, length)) writestring(result, offset + 1, data) - return offset + 1 + length + return result, offset + 1 + length, size elseif length <= 0xFF then + if offset + 2 + length > size then + result, size = inflate(result, offset + 2 + length, size) + end + writeu8(result, offset, 0xD9) writeu8(result, offset + 1, length) writestring(result, offset + 2, data) - return offset + 2 + length + return result, offset + 2 + length, size elseif length <= 0xFFFF then + if offset + 3 + length > size then + result, size = inflate(result, offset + 3 + length, size) + end + writeu8(result, offset, 0xDA) writeu16(result, offset + 1, length) writestring(result, offset + 3, data) - return offset + 3 + length + return result, offset + 3 + length, size elseif length <= 0xFFFFFFFF then + if offset + 5 + length > size then + result, size = inflate(result, offset + 5 + length, size) + end + writeu8(result, offset, 0xDB) writeu32(result, offset + 1, length) writestring(result, offset + 5, data) - return offset + 5 + length + return result, offset + 5 + length, size end error("Could not encode - too long string") @@ -556,20 +445,32 @@ local function encode(result: buffer, offset: number, data: any): number local length = bufferLen(data) if length <= 0xFF then + if offset + 2 + length > size then + result, size = inflate(result, offset + 2 + length, size) + end + writeu8(result, offset, 0xC4) writeu8(result, offset + 1, length) bufferCopy(result, offset + 2, data) - return offset + 2 + length + return result, offset + 2 + length, size elseif length <= 0xFFFF then + if offset + 3 + length > size then + result, size = inflate(result, offset + 3 + length, size) + end + writeu8(result, offset, 0xC5) writeu16(result, offset + 1, length) bufferCopy(result, offset + 3, data) - return offset + 3 + length + return result, offset + 3 + length, size elseif length <= 0xFFFFFFFF then + if offset + 5 + length > size then + result, size = inflate(result, offset + 5 + length, size) + end + writeu8(result, offset, 0xC6) writeu32(result, offset + 1, length) bufferCopy(result, offset + 5, data) - return offset + 5 + length + return result, offset + 5 + length, size end error("Could not encode - too long binary buffer") @@ -577,62 +478,114 @@ local function encode(result: buffer, offset: number, data: any): number elseif dtype == "number" then -- represents NaN, Inf, -Inf as float 32 to save space if data == 0 then + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writeu8(result, offset, 0) - return offset + 1 + return result, offset + 1, size elseif data ~= data then -- NaN + if offset + 5 > size then + result, size = inflate(result, offset + 5, size) + end + writestring(result, offset, "\xCA\x7F\x80\x00\x01") - return offset + 5 + return result, offset + 5, size elseif data == math.huge then + if offset + 5 > size then + result, size = inflate(result, offset + 5, size) + end + writestring(result, offset, "\xCA\x7F\x80\x00\x00") - return offset + 5 + return result, offset + 5, size elseif data == -math.huge then + if offset + 5 > size then + result, size = inflate(result, offset + 5, size) + end + writestring(result, offset, "\xCA\xFF\x80\x00\x00") - return offset + 5 + return result, offset + 5, size end local integral, fractional = modf(data) local sign = sign(data) if fractional ~= 0 or integral > 0xFFFFFFFF or integral < -0x80000000 then + if offset + 9 > size then + result, size = inflate(result, offset + 9, size) + end + -- float 64 writeu8(result, offset, 0xCB) writef64(result, offset + 1, data) - return offset + 9 + return result, offset + 9, size end if sign > 0 then if integral <= 127 then -- positive fixint + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writeu8(result, offset, integral) - return offset + 1 + return result, offset + 1, size elseif integral <= 0xFF then -- uint 8 + if offset + 2 > size then + result, size = inflate(result, offset + 2, size) + end + writeu8(result, offset, 0xCC) writeu8(result, offset + 1, integral) - return offset + 2 + return result, offset + 2, size elseif integral <= 0xFFFF then -- uint 16 + if offset + 3 > size then + result, size = inflate(result, offset + 3, size) + end + writeu8(result, offset, 0xCD) writeu16(result, offset + 1, integral) - return offset + 3 + return result, offset + 3, size elseif integral <= 0xFFFFFFFF then -- uint 32 + if offset + 5 > size then + result, size = inflate(result, offset + 5, size) + end + writeu8(result, offset, 0xCE) writeu32(result, offset + 1, integral) - return offset + 5 + return result, offset + 5, size end else if integral >= -0x20 then -- negative fixint + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writeu8(result, offset, bor(0xE0, extract(integral, 0, 5))) - return offset + 1 + return result, offset + 1, size elseif integral >= -0x80 then -- int 8 + if offset + 2 > size then + result, size = inflate(result, offset + 2, size) + end + writeu8(result, offset, 0xD0) writei8(result, offset + 1, integral) - return offset + 2 + return result, offset + 2, size elseif integral >= -0x8000 then -- int 16 + if offset + 3 > size then + result, size = inflate(result, offset + 3, size) + end + writeu8(result, offset, 0xD1) writei16(result, offset + 1, integral) - return offset + 3 + return result, offset + 3, size elseif integral >= -0x80000000 then -- int 32 + if offset + 5 > size then + result, size = inflate(result, offset + 5, size) + end + writeu8(result, offset, 0xD2) writei32(result, offset + 1, integral) - return offset + 5 + return result, offset + 5, size end end @@ -643,46 +596,72 @@ local function encode(result: buffer, offset: number, data: any): number if msgpackType then if msgpackType == msgpack.Int64 or msgpackType == msgpack.UInt64 then + if offset + 9 > size then + result, size = inflate(result, offset + 9, size) + end + local intType = if msgpackType == msgpack.UInt64 then 0xCF else 0xD3 writeu8(result, offset, intType) writeu32(result, offset + 1, data.mostSignificantPart) writeu32(result, offset + 5, data.leastSignificantPart) - return offset + 9 + return result, offset + 9, size elseif msgpackType == msgpack.Extension then local length = bufferLen(data.data) local extType = extensionTypeLUT[length] if extType then + if offset + 2 + length > size then + result, size = inflate(result, offset + 2 + length, size) + end + writeu8(result, offset, extType) writeu8(result, offset + 1, data.type) bufferCopy(result, offset + 2, data.data) - return offset + 2 + length + return result, offset + 2 + length, size end if length <= 0xFF then + if offset + 3 + length > size then + result, size = inflate(result, offset + 3 + length, size) + end + writeu8(result, offset, 0xC7) writeu8(result, offset + 1, length) writeu8(result, offset + 2, data.type) bufferCopy(result, offset + 3, data.data) - return offset + 3 + length + return result, offset + 3 + length, size elseif length <= 0xFFFF then + if offset + 4 + length > size then + result, size = inflate(result, offset + 4 + length, size) + end + writeu8(result, offset, 0xC8) writeu16(result, offset + 1, length) writeu8(result, offset + 3, data.type) bufferCopy(result, offset + 4, data.data) - return offset + 4 + length + return result, offset + 4 + length, size elseif length <= 0xFFFFFFFF then + if offset + 6 + length > size then + result, size = inflate(result, offset + 6 + length, size) + end + writeu8(result, offset, 0xC9) writeu32(result, offset + 1, length) writeu8(result, offset + 5, data.type) bufferCopy(result, offset + 6, data.data) - return offset + 6 + length + return result, offset + 6 + length, size end error("Could not encode - too long extension data") end end + if tableSet[data] then + error("Can not serialize cyclic table") + else + tableSet[data] = true + end + local length = #data local mapLength = 0 @@ -693,13 +672,25 @@ local function encode(result: buffer, offset: number, data: any): number if length == mapLength then -- array local newOffset = offset if length <= 15 then + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writeu8(result, offset, bor(0x90, mapLength)) newOffset += 1 elseif length <= 0xFFFF then + if offset + 3 > size then + result, size = inflate(result, offset + 3, size) + end + writeu8(result, offset, 0xDC) writeu16(result, offset + 1, length) newOffset += 3 elseif length <= 0xFFFFFFFF then + if offset + 5 > size then + result, size = inflate(result, offset + 5, size) + end + writeu8(result, offset, 0xDD) writeu32(result, offset + 1, length) newOffset += 5 @@ -708,21 +699,33 @@ local function encode(result: buffer, offset: number, data: any): number end for _,v in ipairs(data) do - newOffset = encode(result, newOffset, v) + result, newOffset, size = encode(result, newOffset, size, v, tableSet) end - return newOffset + return result, newOffset, size else -- map local newOffset = offset if mapLength <= 15 then + if offset + 1 > size then + result, size = inflate(result, offset + 1, size) + end + writeu8(result, offset, bor(0x80, mapLength)) newOffset += 1 elseif mapLength <= 0xFFFF then + if offset + 3 > size then + result, size = inflate(result, offset + 3, size) + end + writeu8(result, offset, 0xDE) writeu16(result, offset + 1, mapLength) newOffset += 3 elseif mapLength <= 0xFFFFFFFF then + if offset + 5 > size then + result, size = inflate(result, offset + 5, size) + end + writeu8(result, offset, 0xDF) writeu32(result, offset + 1, mapLength) newOffset += 5 @@ -731,11 +734,11 @@ local function encode(result: buffer, offset: number, data: any): number end for k,v in pairs(data) do - newOffset = encode(result, newOffset, k) - newOffset = encode(result, newOffset, v) + result, newOffset, size = encode(result, newOffset, size, k, tableSet) + result, newOffset, size = encode(result, newOffset, size, v, tableSet) end - return newOffset + return result, newOffset, size end end @@ -831,10 +834,8 @@ function msgpack.decode(message: string): any end function msgpack.encode(data: any): string - local length = computeLength(data, {}) - local result = bufferCreate(length) - encode(result, 0, data) - return buffer.tostring(result) + local result, offset = encode(bufferCreate(64), 0, 64, data, {}) + return readstring(result, 0, offset) end export type Int64 = { _msgpackType: typeof(msgpack.Int64), mostSignificantPart: number, leastSignificantPart: number }