From dda80d37d962ae96d483baa7752b378c74222ef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Ciejka?= Date: Mon, 1 Nov 2021 11:02:56 +0000 Subject: [PATCH 01/10] feat: added doh support post/get --- README.markdown | 5 +- lib/resty/dns/resolver.lua | 96 ++++++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 14 deletions(-) diff --git a/README.markdown b/README.markdown index 28b062a..93b4b3c 100644 --- a/README.markdown +++ b/README.markdown @@ -125,7 +125,7 @@ It accepts a `opts` table argument. The following options are supported: * `nameservers` - a list of nameservers to be used. Each nameserver entry can be either a single hostname string or a table holding both the hostname string and the port number. The nameserver is picked up by a simple round-robin algorithm for each `query` method call. This option is required. + a list of nameservers to be used. Each nameserver entry can be either a single hostname string, DoH url with optional port or a table holding both the hostname string and the port number. The nameserver is picked up by a simple round-robin algorithm for each `query` method call. This option is required. * `retrans` the total number of times of retransmitting the DNS request when receiving a DNS response times out according to the `timeout` setting. Defaults to `5` times. When trying to retransmit the query, the next nameserver according to the round-robin algorithm will be picked up. @@ -138,7 +138,10 @@ It accepts a `opts` table argument. The following options are supported: * `no_random` a boolean flag controls whether to randomly pick the nameserver to query first, if `true` will always start with the first nameserver listed. Defaults to `false`. +* `doh` + type of DoH query possible values are `POST`, `GET` or boolean false, Defaults to nil. + [Back to TOC](#table-of-contents) query diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index a67b3c1..9100d36 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -25,7 +25,7 @@ local unpack = unpack local setmetatable = setmetatable local type = type local ipairs = ipairs - +local b64 = require "ngx.base64" local ok, new_tab = pcall(require, "table.new") if not ok then @@ -110,6 +110,10 @@ function _M.new(class, opts) return nil, "no nameservers specified" end + if opts.doh ~= nil and (opts.doh ~= 'POST' and opts.doh ~= 'GET') then + return nil, "invalid DoH mode specified" + end + local timeout = opts.timeout or 2000 -- default 2 sec local n = #servers @@ -118,30 +122,32 @@ function _M.new(class, opts) for i = 1, n do local server = servers[i] - local sock, err = udp() - if not sock then - return nil, "failed to create udp socket: " .. err - end - local host, port + if type(server) == 'table' then host = server[1] port = server[2] or 53 - else host = server port = 53 servers[i] = {host, port} end - local ok, err = sock:setpeername(host, port) - if not ok then - return nil, "failed to set peer name: " .. err - end + if not opts.doh then + local sock, err = udp() + if not sock then + return nil, "failed to create udp socket: " .. err + end - sock:settimeout(timeout) + local ok, err = sock:setpeername(host, port) + if not ok then + return nil, "failed to set peer name: " .. err + end + + sock:settimeout(timeout) - insert(socks, sock) + insert(socks, sock) + end end local tcp_sock, err = tcp() @@ -158,6 +164,7 @@ function _M.new(class, opts) servers = servers, retrans = opts.retrans or 5, no_recurse = opts.no_recurse, + doh = opts.doh }, mt) end @@ -832,6 +839,69 @@ end function _M.query(self, qname, opts, tries) + if self.doh then + return _M.doh_query(self,qname,opts,tries) + end + + return _M.udp_tcp_query(self,qname,opts,tries) +end + +function _M.doh_query(self, qname, opts, tries) + local retrans = self.retrans + if tries then + tries[1] = nil + end + + local servers = self.servers + + if #servers == 0 then + return nil, "No servers available" + end + + local err + + for i = 1, retrans do + local idx = i + + if idx > #servers then + idx = 1 + end + + local res + local id + + if self.doh == 'GET' then + res = ngx.location.capture(servers[idx][1] .. b64.encode_base64url(qname)) + else + id = _gen_id(self) + res = ngx.location.capture(servers[idx][1], + { method = ngx.HTTP_POST, body = _build_request(qname, id, self.no_recurse, opts) }) + end + + if res.status == 200 and res.body then + local answers + if self.doh == 'GET' then + local ident_hi = byte(res.body, 1) + local ident_lo = byte(res.body, 2) + id = lshift(ident_hi, 8) + ident_lo + end + answers, err = parse_response(res.body, id, opts) + if answers then + return answers, nil, tries + end + end + + if tries then + tries[i] = err + tries[i + 1] = nil -- ensure termination for user supplied table + end + end + + return nil, err, tries +end + + +function _M.udp_tcp_query(self, qname, opts, tries) local socks = self.socks if not socks then return nil, "not initialized" From 810558357ccb4917fce31023e61c6dbecb68c7f3 Mon Sep 17 00:00:00 2001 From: Pawel Ciejka Date: Mon, 1 Nov 2021 20:42:26 +0000 Subject: [PATCH 02/10] bugfix: convert table to string using concat in doh post request --- lib/resty/dns/resolver.lua | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index 9100d36..b5e2138 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -874,8 +874,8 @@ function _M.doh_query(self, qname, opts, tries) res = ngx.location.capture(servers[idx][1] .. b64.encode_base64url(qname)) else id = _gen_id(self) - res = ngx.location.capture(servers[idx][1], - { method = ngx.HTTP_POST, body = _build_request(qname, id, self.no_recurse, opts) }) + local bdata = table.concat(_build_request(qname, id, self.no_recurse, opts)) + res = ngx.location.capture(servers[idx][1],{ method = ngx.HTTP_POST, body = bdata }) end if res.status == 200 and res.body then From ef0ef257b6cdc995c40215319bfdd295e20eaede Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Ciejka?= Date: Wed, 3 Nov 2021 20:21:50 +0000 Subject: [PATCH 03/10] fix: moved sub query functions to local --- lib/resty/dns/resolver.lua | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index b5e2138..f74a085 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -840,13 +840,13 @@ end function _M.query(self, qname, opts, tries) if self.doh then - return _M.doh_query(self,qname,opts,tries) + return doh_query(self,qname,opts,tries) end - return _M.udp_tcp_query(self,qname,opts,tries) + return udp_tcp_query(self,qname,opts,tries) end -function _M.doh_query(self, qname, opts, tries) +local function doh_query(self, qname, opts, tries) local retrans = self.retrans if tries then tries[1] = nil @@ -901,7 +901,7 @@ function _M.doh_query(self, qname, opts, tries) end -function _M.udp_tcp_query(self, qname, opts, tries) +local function udp_tcp_query(self, qname, opts, tries) local socks = self.socks if not socks then return nil, "not initialized" From 90c590288a7e3a8b0b88a5b163498b37d8fe6ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Ciejka?= Date: Wed, 3 Nov 2021 20:30:46 +0000 Subject: [PATCH 04/10] fix: added nil check on response --- lib/resty/dns/resolver.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index f74a085..a222581 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -878,7 +878,7 @@ local function doh_query(self, qname, opts, tries) res = ngx.location.capture(servers[idx][1],{ method = ngx.HTTP_POST, body = bdata }) end - if res.status == 200 and res.body then + if res ~= nil and res.status == 200 and res.body then local answers if self.doh == 'GET' then local ident_hi = byte(res.body, 1) From c3b80c24fbf17a15223aad5127ed9a5c43103d49 Mon Sep 17 00:00:00 2001 From: Pawel Ciejka Date: Wed, 3 Nov 2021 20:37:51 +0000 Subject: [PATCH 05/10] fix: changed functions order --- lib/resty/dns/resolver.lua | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index a222581..22fad95 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -838,14 +838,6 @@ function _M.tcp_query(self, qname, opts) end -function _M.query(self, qname, opts, tries) - if self.doh then - return doh_query(self,qname,opts,tries) - end - - return udp_tcp_query(self,qname,opts,tries) -end - local function doh_query(self, qname, opts, tries) local retrans = self.retrans if tries then @@ -974,6 +966,13 @@ local function udp_tcp_query(self, qname, opts, tries) return nil, err, tries end +function _M.query(self, qname, opts, tries) + if self.doh then + return doh_query(self,qname,opts,tries) + end + + return udp_tcp_query(self,qname,opts,tries) +end function _M.compress_ipv6_addr(addr) local addr = re_sub(addr, "^(0:)+|(:0)+$|:(0:)+", "::", "jo") From 5ecffd58eb9a4a5df94bfe13f631824c952b5856 Mon Sep 17 00:00:00 2001 From: Pawel Ciejka Date: Tue, 16 Nov 2021 23:17:11 +0000 Subject: [PATCH 06/10] fix: changed ngx.location.capture to internal simple http client --- lib/resty/dns/resolver.lua | 406 +++++++++++++++++++++++++++++-------- 1 file changed, 323 insertions(+), 83 deletions(-) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index 22fad95..88b2eda 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -18,6 +18,8 @@ local lshift = bit.lshift local insert = table.insert local concat = table.concat local re_sub = ngx.re.sub +local re_match = ngx.re.match +local re_find = ngx.re.find local tcp = ngx.socket.tcp local log = ngx.log local DEBUG = ngx.DEBUG @@ -26,6 +28,11 @@ local setmetatable = setmetatable local type = type local ipairs = ipairs local b64 = require "ngx.base64" +local agent = "ngx_lua/" .. ngx.config.ngx_lua_version +local str_lower = string.lower +local ngx_ERR = ngx.ERR +local tbl_insert = table.insert +local tolower = string.lower local ok, new_tab = pcall(require, "table.new") if not ok then @@ -100,75 +107,6 @@ for i = 2, 64, 2 do end -function _M.new(class, opts) - if not opts then - return nil, "no options table specified" - end - - local servers = opts.nameservers - if not servers or #servers == 0 then - return nil, "no nameservers specified" - end - - if opts.doh ~= nil and (opts.doh ~= 'POST' and opts.doh ~= 'GET') then - return nil, "invalid DoH mode specified" - end - - local timeout = opts.timeout or 2000 -- default 2 sec - - local n = #servers - - local socks = {} - - for i = 1, n do - local server = servers[i] - local host, port - - if type(server) == 'table' then - host = server[1] - port = server[2] or 53 - else - host = server - port = 53 - servers[i] = {host, port} - end - - if not opts.doh then - local sock, err = udp() - if not sock then - return nil, "failed to create udp socket: " .. err - end - - local ok, err = sock:setpeername(host, port) - if not ok then - return nil, "failed to set peer name: " .. err - end - - sock:settimeout(timeout) - - insert(socks, sock) - end - end - - local tcp_sock, err = tcp() - if not tcp_sock then - return nil, "failed to create tcp socket: " .. err - end - - tcp_sock:settimeout(timeout) - - return setmetatable( - { cur = opts.no_random and 1 or rand(1, n), - socks = socks, - tcp_sock = tcp_sock, - servers = servers, - retrans = opts.retrans or 5, - no_recurse = opts.no_recurse, - doh = opts.doh - }, mt) -end - - local function pick_sock(self, socks) local cur = self.cur @@ -838,7 +776,186 @@ function _M.tcp_query(self, qname, opts) end -local function doh_query(self, qname, opts, tries) +local function _http_connect(self,host) + local sock = self.tcp_sock + if not sock then + return nil, "not initialized" + end + + local ok, err = sock:connect(host[1], host[2]) + if not ok then + return nil, "failed to connect to HTTP server " + .. host[1] .. ":" .. host[2] .. ": " .. err + end + + if host[4] and sock:getreusedtimes() == 0 then + local session, err = sock:sslhandshake(nil,host[1]) + if not session then + return nil, err + end + end + + return sock +end + + +local function _http_status_receive(sock) + local line, err, partial = sock:receive("*l") + if not line then + return nil, nil, nil, "failed to read http header status line: "..err + end + + local ret, err = re_match(line,"(HTTP/[0-3](\\.[0-1])?) ([1-5][0-9]{2}) ([A-Za-z ]+)") + + if not ret then + return nil, nil, nil, "failed to parse http status with error: "..err + end + + return ret[1], tonumber(ret[3]), ret[4] +end + + +local function _http_header_receive(sock) + local ret = {} + + repeat + local line, err = sock:receive("*l") + if not line then + return nil, err + end + + local m, err = re_match(line, "([^:\\s]+):\\s*(.*)", "jo") + if err then log(DEBUG, err) end + + if not m then + break + end + + local key = string.lower(m[1]) + local val = m[2] + + if ret[key] then + if type(ret[key]) ~= "table" then + ret[key] = { ret[key] } + end + tbl_insert(ret[key], tostring(val)) + else + ret[key] = tostring(val) + end + until re_find(line, "^\\s*$", "jo") + + return ret +end + + +local function _http_header_send(sock, host, method, length, param) + local hoststr + + if (host[4] and host[2] ~= 443) or (not host[4] and host[2] ~= 80) then + hoststr = host[1]..":"..host[2] + else + hoststr = host[1] + end + + local query = { + true, + 'Host: '..hoststr, + 'User-Agent: '..agent, + 'Accept: application/dns-message', + 'Connection: keep-alive' + } + + if method == nil or method == ngx.HTTP_GET then + query[1] = 'GET '.. host[3]..param.. ' HTTP/1.1' + query = concat(query,"\r\n").."\r\n\r\n" + elseif method == ngx.HTTP_POST then + query[1] = 'POST '.. host[3] .. ' HTTP/1.1' + insert(query,'Content-Length: '..length) + insert(query,'Content-Type: application/dns-message') + query = concat(query, "\r\n").."\r\n\r\n" + else + return nil, "unsupported method" + end + + local bytes, err = sock:send(query) + + if not bytes then + return 0, err + end + + return bytes +end + + +local function _http_body_receive(sock, header) + local len = header["content-length"] + + if header["content-type"] ~= "application/dns-message" then + return nil, "http query failed invalid Content-Type: "..header["content-type"] + end + + local data, err = sock:receiveany(tonumber(len)) + + if not data then + return nil, "http query failed to receive body "..err + end + + return data +end + + +local function _http_query(self,host,opts) + local sock, err = _http_connect(self, host) + + if not sock then + return nil, err + end + + local bytes, err = _http_header_send(sock, host, opts.method, opts.body and #opts.body or 0, opts.param) + + if not bytes then + return nil, err + end + + if opts.body then + local bytes, err = sock:send(opts.body) + if not bytes or bytes < #opts.body then + return nil, "http POST query failed body not sent" + end + end + + local version, status, reason, err = _http_status_receive(sock) + + if err then + return nil, err + end + + if status ~= 200 then + return nil, "http query failed status code is: "..status.." reason: "..reason + end + + local header, err = _http_header_receive(sock) + + if not header then + return nil, err + end + + local data, err = _http_body_receive(sock, header) + + if not data then + return nil, err + end + + sock:setkeepalive() + + return { + status = status, + version = version, + body = data + } +end + +local function _doh_query(self, qname, opts, tries) local retrans = self.retrans if tries then tries[1] = nil @@ -847,7 +964,7 @@ local function doh_query(self, qname, opts, tries) local servers = self.servers if #servers == 0 then - return nil, "No servers available" + return nil, "no servers available" end local err @@ -861,16 +978,20 @@ local function doh_query(self, qname, opts, tries) local res local id - + if self.doh == 'GET' then - res = ngx.location.capture(servers[idx][1] .. b64.encode_base64url(qname)) + res = _http_query(self,servers[idx], { method = ngx.HTTP_GET, param = b64.encode_base64url(qname) }) else id = _gen_id(self) local bdata = table.concat(_build_request(qname, id, self.no_recurse, opts)) - res = ngx.location.capture(servers[idx][1],{ method = ngx.HTTP_POST, body = bdata }) + res, err = _http_query(self,servers[idx],{ method = ngx.HTTP_POST, body = bdata }) end - if res ~= nil and res.status == 200 and res.body then + if not res then + return nil, err, tries + end + + if res.status == 200 and res.body then local answers if self.doh == 'GET' then local ident_hi = byte(res.body, 1) @@ -881,6 +1002,12 @@ local function doh_query(self, qname, opts, tries) if answers then return answers, nil, tries end + + if err and err ~= "id mismatch" then + break + else + log(DEBUG,"doh query failed to parse response",err) + end end if tries then @@ -893,7 +1020,7 @@ local function doh_query(self, qname, opts, tries) end -local function udp_tcp_query(self, qname, opts, tries) +local function _udp_tcp_query(self, qname, opts, tries) local socks = self.socks if not socks then return nil, "not initialized" @@ -966,13 +1093,6 @@ local function udp_tcp_query(self, qname, opts, tries) return nil, err, tries end -function _M.query(self, qname, opts, tries) - if self.doh then - return doh_query(self,qname,opts,tries) - end - - return udp_tcp_query(self,qname,opts,tries) -end function _M.compress_ipv6_addr(addr) local addr = re_sub(addr, "^(0:)+|(:0)+$|:(0:)+", "::", "jo") @@ -1048,4 +1168,124 @@ function _M.reverse_query(self, addr) end +local function _new_doh(class,opts) + if opts.doh ~= 'POST' and opts.doh ~= 'GET' then + return nil, "invalid DoH mode specified" + end + + local servers = opts.nameservers + local n = #servers + + for i = 1, n do + local captures, err = re_match(servers[i],"^((https?)(://))?([A-Za-z0-9\\.-]+)(:[1-9][0-9]*)?(/.+)$") + + if not captures then + return nil, err + end + + local host = captures[4] + local ssl = (captures[1] == 'https://') and true or false + local port + + if captures[5] then + port = tonumber(sub(captures[5],2)) + elseif not ssl then + port = 80 + else + port = 443 + end + + if not port then + return nil, "invalid port specified" + end + + servers[i] = { host, port, captures[6], ssl} + end + + _M.query = _doh_query + + return servers +end + +local function _new_tcp_udp(class,opts,timeout) + local servers = opts.nameservers + local n = #servers + local socks = {} + + for i = 1, n do + local server = servers[i] + local host, port, ssl + + if type(server) == 'table' then + host = server[1] + port = server[2] or 53 + else + host = server + port = 53 + servers[i] = {host, port} + end + + local sock, err = udp() + if not sock then + return nil, "failed to create udp socket: " .. err + end + + local ok, err = sock:setpeername(host, port) + if not ok then + return nil, "failed to set peer name: " .. err + end + + sock:settimeout(timeout) + + insert(socks, sock) + end + + _M.query = _udp_tcp_query + + return servers,socks +end + + +function _M.new(class, opts) + if not opts then + return nil, "no options table specified" + end + + local servers = opts.nameservers + if not servers or #servers == 0 then + return nil, "no nameservers specified" + end + + local timeout = opts.timeout or 2000 -- default 2 sec + local servers, socks, err + + if opts.doh then + servers, err = _new_doh(class,opts) + else + servers, socks, err = _new_tcp_udp(class,opts,timeout) + end + + if not servers then + return nil, err + end + + local tcp_sock, err = tcp() + if not tcp_sock then + return nil, "failed to create tcp socket: " .. err + end + + tcp_sock:settimeout(timeout) + + return setmetatable( + { cur = opts.no_random and 1 or rand(1, n), + socks = socks, + tcp_sock = tcp_sock, + servers = servers, + retrans = opts.retrans or 5, + no_recurse = opts.no_recurse, + doh = opts.doh + }, mt) +end + + return _M From 13e42fd31a6650a952a4d48c77c2bc956ebe7916 Mon Sep 17 00:00:00 2001 From: Pawel Ciejka Date: Wed, 17 Nov 2021 03:39:48 +0000 Subject: [PATCH 07/10] fix: opts.doh is now boolean, added opts.doh_method, updated README --- README.markdown | 5 ++++- lib/resty/dns/resolver.lua | 11 ++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/README.markdown b/README.markdown index 93b4b3c..3c6a0a0 100644 --- a/README.markdown +++ b/README.markdown @@ -139,8 +139,11 @@ It accepts a `opts` table argument. The following options are supported: a boolean flag controls whether to randomly pick the nameserver to query first, if `true` will always start with the first nameserver listed. Defaults to `false`. * `doh` + a boolean flag controls whether to use DNS over Https (DoH) - type of DoH query possible values are `POST`, `GET` or boolean false, Defaults to nil. +* `doh_method` + + type of DoH query possible values are `POST` or `GET` or boolean false, Defaults to nil. [Back to TOC](#table-of-contents) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index 88b2eda..df25d71 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -979,7 +979,7 @@ local function _doh_query(self, qname, opts, tries) local res local id - if self.doh == 'GET' then + if self.doh_method == 'GET' then res = _http_query(self,servers[idx], { method = ngx.HTTP_GET, param = b64.encode_base64url(qname) }) else id = _gen_id(self) @@ -993,7 +993,7 @@ local function _doh_query(self, qname, opts, tries) if res.status == 200 and res.body then local answers - if self.doh == 'GET' then + if self.doh_method == 'GET' then local ident_hi = byte(res.body, 1) local ident_lo = byte(res.body, 2) id = lshift(ident_hi, 8) + ident_lo @@ -1169,7 +1169,7 @@ end local function _new_doh(class,opts) - if opts.doh ~= 'POST' and opts.doh ~= 'GET' then + if opts.doh_method ~= 'POST' and opts.doh_method ~= 'GET' then return nil, "invalid DoH mode specified" end @@ -1283,8 +1283,9 @@ function _M.new(class, opts) servers = servers, retrans = opts.retrans or 5, no_recurse = opts.no_recurse, - doh = opts.doh - }, mt) + doh = opts.doh, + doh_method = opts.doh_method + }, mt) end From edb1932ceaf3a65a997897917062ffc9c02b665f Mon Sep 17 00:00:00 2001 From: Pawel Ciejka Date: Wed, 17 Nov 2021 05:07:18 +0000 Subject: [PATCH 08/10] fix: DoH method translation from string to ngx constant --- lib/resty/dns/resolver.lua | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index df25d71..56d4e4a 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -979,7 +979,7 @@ local function _doh_query(self, qname, opts, tries) local res local id - if self.doh_method == 'GET' then + if self.doh_method == ngx.HTTP_GET then res = _http_query(self,servers[idx], { method = ngx.HTTP_GET, param = b64.encode_base64url(qname) }) else id = _gen_id(self) @@ -993,7 +993,7 @@ local function _doh_query(self, qname, opts, tries) if res.status == 200 and res.body then local answers - if self.doh_method == 'GET' then + if self.doh_method == ngx.HTTP_GET then local ident_hi = byte(res.body, 1) local ident_lo = byte(res.body, 2) id = lshift(ident_hi, 8) + ident_lo @@ -1169,8 +1169,13 @@ end local function _new_doh(class,opts) - if opts.doh_method ~= 'POST' and opts.doh_method ~= 'GET' then - return nil, "invalid DoH mode specified" + local method + if opts.doh_method == 'POST' then + method = ngx.HTTP_POST + elseif opts.doh_method == 'GET' then + method = ngx.HTTP_GET + else + return nil, nil, "invalid DoH mode specified" end local servers = opts.nameservers @@ -1180,7 +1185,7 @@ local function _new_doh(class,opts) local captures, err = re_match(servers[i],"^((https?)(://))?([A-Za-z0-9\\.-]+)(:[1-9][0-9]*)?(/.+)$") if not captures then - return nil, err + return nil, nil, err end local host = captures[4] @@ -1196,7 +1201,7 @@ local function _new_doh(class,opts) end if not port then - return nil, "invalid port specified" + return nil, nil, "invalid port specified" end servers[i] = { host, port, captures[6], ssl} @@ -1204,7 +1209,7 @@ local function _new_doh(class,opts) _M.query = _doh_query - return servers + return servers, method end local function _new_tcp_udp(class,opts,timeout) @@ -1258,9 +1263,10 @@ function _M.new(class, opts) local timeout = opts.timeout or 2000 -- default 2 sec local servers, socks, err + local method if opts.doh then - servers, err = _new_doh(class,opts) + servers, method, err = _new_doh(class,opts) else servers, socks, err = _new_tcp_udp(class,opts,timeout) end @@ -1275,7 +1281,7 @@ function _M.new(class, opts) end tcp_sock:settimeout(timeout) - + return setmetatable( { cur = opts.no_random and 1 or rand(1, n), socks = socks, @@ -1284,7 +1290,7 @@ function _M.new(class, opts) retrans = opts.retrans or 5, no_recurse = opts.no_recurse, doh = opts.doh, - doh_method = opts.doh_method + doh_method = method }, mt) end From 40f7d668b23f37a3cbd8f52922aba3fa8f344078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Ciejka?= Date: Wed, 17 Nov 2021 05:17:02 +0000 Subject: [PATCH 09/10] bugfix: number of servers passed to rand --- lib/resty/dns/resolver.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index 56d4e4a..e962526 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -1283,7 +1283,7 @@ function _M.new(class, opts) tcp_sock:settimeout(timeout) return setmetatable( - { cur = opts.no_random and 1 or rand(1, n), + { cur = opts.no_random and 1 or rand(1, #servers), socks = socks, tcp_sock = tcp_sock, servers = servers, From ddbbe2b76b3c137d6c9448e0107664401207be3c Mon Sep 17 00:00:00 2001 From: paparuch Date: Fri, 30 Sep 2022 03:54:15 +0000 Subject: [PATCH 10/10] incomplete implementation of any dns request --- README.markdown | 4 + lib/resty/dns/resolver.lua | 1542 ++++++++++++++++------------------ lib/resty/dns/wireformat.lua | 596 +++++++++++++ t/mock.t | 94 +++ 4 files changed, 1427 insertions(+), 809 deletions(-) create mode 100644 lib/resty/dns/wireformat.lua diff --git a/README.markdown b/README.markdown index 3c6a0a0..ad809cc 100644 --- a/README.markdown +++ b/README.markdown @@ -139,11 +139,15 @@ It accepts a `opts` table argument. The following options are supported: a boolean flag controls whether to randomly pick the nameserver to query first, if `true` will always start with the first nameserver listed. Defaults to `false`. * `doh` + a boolean flag controls whether to use DNS over Https (DoH) * `doh_method` type of DoH query possible values are `POST` or `GET` or boolean false, Defaults to nil. +* `doh_bootstrap` + + list of nameservers used to perform initial query for IP address of DoH servers. [Back to TOC](#table-of-contents) diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index e962526..351ff5f 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -2,8 +2,29 @@ -- local socket = require "socket" -local bit = require "bit" + +local ok, b64 = pcall(require,"ngx.base64") +if not ok then + return false +end + +local ok, bit = pcall(require, "bit") +if not ok then + return false +end + +local ok, wire = pcall(require, "resty.dns.wireformat") +if not ok then + return false +end + +local ok, new_tab = pcall(require, "table.new") +if not ok then + new_tab = function (narr, nrec) return {} end +end + local udp = ngx.socket.udp +local tcp = ngx.socket.tcp local rand = math.random local char = string.char local byte = string.byte @@ -12,92 +33,34 @@ local gsub = string.gsub local sub = string.sub local rep = string.rep local format = string.format -local band = bit.band -local rshift = bit.rshift -local lshift = bit.lshift local insert = table.insert local concat = table.concat local re_sub = ngx.re.sub local re_match = ngx.re.match local re_find = ngx.re.find -local tcp = ngx.socket.tcp local log = ngx.log local DEBUG = ngx.DEBUG local unpack = unpack local setmetatable = setmetatable local type = type local ipairs = ipairs -local b64 = require "ngx.base64" local agent = "ngx_lua/" .. ngx.config.ngx_lua_version local str_lower = string.lower -local ngx_ERR = ngx.ERR -local tbl_insert = table.insert local tolower = string.lower +local ngx_get = ngx.HTTP_GET +local ngx_post = ngx.HTT_POST +local band = bit.band +local wire_build = wire.build_request +local wire_parse = wire.parse_response +local bit = require "bit" +local band = bit.band +local rshift = bit.rshift +local lshift = bit.lshift -local ok, new_tab = pcall(require, "table.new") -if not ok then - new_tab = function (narr, nrec) return {} end -end - - -local DOT_CHAR = byte(".") -local ZERO_CHAR = byte("0") -local COLON_CHAR = byte(":") +local arpa_tmpl = new_tab(72, 0) local IP6_ARPA = "ip6.arpa" -local TYPE_A = 1 -local TYPE_NS = 2 -local TYPE_CNAME = 5 -local TYPE_SOA = 6 -local TYPE_PTR = 12 -local TYPE_MX = 15 -local TYPE_TXT = 16 -local TYPE_AAAA = 28 -local TYPE_SRV = 33 -local TYPE_SPF = 99 - -local CLASS_IN = 1 - -local SECTION_AN = 1 -local SECTION_NS = 2 -local SECTION_AR = 3 - - -local _M = { - _VERSION = '0.22', - TYPE_A = TYPE_A, - TYPE_NS = TYPE_NS, - TYPE_CNAME = TYPE_CNAME, - TYPE_SOA = TYPE_SOA, - TYPE_PTR = TYPE_PTR, - TYPE_MX = TYPE_MX, - TYPE_TXT = TYPE_TXT, - TYPE_AAAA = TYPE_AAAA, - TYPE_SRV = TYPE_SRV, - TYPE_SPF = TYPE_SPF, - CLASS_IN = CLASS_IN, - SECTION_AN = SECTION_AN, - SECTION_NS = SECTION_NS, - SECTION_AR = SECTION_AR -} - - -local resolver_errstrs = { - "format error", -- 1 - "server failure", -- 2 - "name error", -- 3 - "not implemented", -- 4 - "refused", -- 5 -} - -local soa_int32_fields = { "serial", "refresh", "retry", "expire", "minimum" } - -local mt = { __index = _M } - - -local arpa_tmpl = new_tab(72, 0) - for i = 1, #IP6_ARPA do arpa_tmpl[64 + i] = byte(IP6_ARPA, i) end @@ -106,682 +69,675 @@ for i = 2, 64, 2 do arpa_tmpl[i] = DOT_CHAR end +local COLON_CHAR = byte(":") -local function pick_sock(self, socks) - local cur = self.cur - - if cur == #socks then - self.cur = 1 - else - self.cur = cur + 1 - end - - return socks[cur] -end - - -local function _get_cur_server(self) - local cur = self.cur - - local servers = self.servers - - if cur == 1 then - return servers[#servers] - end - - return servers[cur - 1] -end +local _M = { + _VERSION = '0.22', + TYPE_A = wire.TYPE.A, + TYPE_NS = wire.TYPE.NS, + TYPE_CNAME = wire.TYPE.CNAME, + TYPE_SOA = wire.TYPE.SOA, + TYPE_PTR = wire.TYPE.PTR, + TYPE_MX = wire.TYPE.MX, + TYPE_TXT = wire.TYPE.TXT, + TYPE_AAAA = wire.TYPE.AAAA, + TYPE_SRV = wire.TYPE.SRV, + TYPE_SPF = wire.TYPE.SPF, + CLASS_IN = wire.CLASS.IN, + SECTION_AN = wire.SECTION.AN, + SECTION_NS = wire.SECTION.NS, + SECTION_AR = wire.SECTION.AR, + MODE = { + UDP = 1, + TCP = 2, + UDP_TCP = 3, + DOT = 4, + DOH = 8 + } +} +local MODE_UDP = _M.MODE.UDP +local MODE_TCP = _M.MODE.TCP +local MODE_DOT = _M.MODE.DOT +local MODE_DOH = _M.MODE.DOH +local MODE_UDP_TCP = _M.MODE.UDP_TCP -function _M.set_timeout(self, timeout) - local socks = self.socks - if not socks then - return nil, "not initialized" - end +local DOH_METHOD = { + GET = ngx_get, + POST = ngx_post +} - for i = 1, #socks do - local sock = socks[i] - sock:settimeout(timeout) +local function _is_ip(str) + if type(str) ~= "string" then + return false end - - local tcp_sock = self.tcp_sock - if not tcp_sock then - return nil, "not initialized" + + local ret, err = re_match(str,"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)") + + if ret then + return true end - - tcp_sock:settimeout(timeout) + + return false, err end - -local function _encode_name(s) - return char(#s) .. s +local function _gen_id() -- (self) + --local id = self._id -- for regression testing + --if id then + -- return id + --end + return rand(0, 65535) -- two bytes end +----------------------[[ private implementation ]]---------------------------------------- -local function _decode_name(buf, pos) - local labels = {} - local nptrs = 0 - local p = pos - while nptrs < 128 do - local fst = byte(buf, p) - - if not fst then - return nil, 'truncated'; - end - - -- print("fst at ", p, ": ", fst) - - if fst == 0 then - if nptrs == 0 then - pos = pos + 1 - end - break - end +local function _build_wire_request(self, qname, id, no_recurse, opts) + return wire_build(qname,id,no_recurse,opts) +end - if band(fst, 0xc0) ~= 0 then - -- being a pointer - if nptrs == 0 then - pos = pos + 2 - end +local function _parse_wire_response(self, data, id, opts) + return wire_parse(data, id, opts) +end - nptrs = nptrs + 1 +local function _build_post_wire_request(self, qname, id, no_recurse, opts) + local opts = { + method = ngx_post, + body = wire_build(qname, id, self.no_recurse, opts) + } + + return id, opts +end - local snd = byte(buf, p + 1) - if not snd then - return nil, 'truncated' - end +local function _build_post_json_request(self, qname, id, no_recurse, opts) + local opts = {} + + return opts +end - p = lshift(band(fst, 0x3f), 8) + snd + 1 +local function _build_get_json_request(self, qname, id, no_recurse, opts) + local opts = { + method = ngx_get, + param = self.doh_encode and b64.encode_base64url(qname) or qname + } + + return opts +end - -- print("resolving ptr ", p, ": ", byte(buf, p)) +local function _build_get_wire_request(self, qname, id, no_recurse, opts) + local opts = { + method = ngx_get, + param = self.doh_encode and b64.encode_base64url(qname) or qname + } + + return opts +end - else - -- being a label - local label = sub(buf, p + 1, p + fst) - insert(labels, label) +local function _parse_doh_wire_response(self, qname, id, no_recurse, opts) + + +end - -- print("resolved label ", label) +local function _parse_doh_json_response(self, qname, id, no_recurse, opts) + +end - p = p + fst + 1 - if nptrs == 0 then - pos = p - end - end - end +--[[ sockets implementation ]]-- - return concat(labels, "."), pos +local function _sock_write(self, data) + return self.fd:send(data) end +local function _sock_read(self) + return self.fd:receive() +end -local function _build_request(qname, id, no_recurse, opts) - local qtype +local function _sock_close(self) + return self.fd:close() +end - if opts then - qtype = opts.qtype - end +local function _sock_settimeout(self,timeout) + return self:settimeout(timeout) +end - if not qtype then - qtype = 1 -- A record +--[[ udp stream implementation ]]-- +local function _udp_open(self, host, port, opts, ip) + local fd = udp() + local addr = ip or self.ip or host or self.host + local pnum = port or self.port + local setts = opts or self.opts + local timeout = setts and setts.timeout or 2000 + + fd:settimeout(timeout) + + local ok, err = fd:setpeername(addr,pnum) + if not ok then + return false, err end + + self.fd = fd + --self.id = _gen_id() + + return true +end - local ident_hi = char(rshift(id, 8)) - local ident_lo = char(band(id, 0xff)) - local flags - if no_recurse then - -- print("found no recurse") - flags = "\0\0" - else - flags = "\1\0" - end +local _udp_stream_mt = { + open = _udp_open, --_udp_cached_open(), + read = _sock_read, + write = _sock_write, + close = _sock_close, + settimeout = _sock_settimeout +} - local nqs = "\0\1" - local nan = "\0\0" - local nns = "\0\0" - local nar = "\0\0" - local typ = char(rshift(qtype, 8), band(qtype, 0xff)) - local class = "\0\1" -- the Internet class +--[[ tcp stream implementation ]]-- +local function _tcp_open(self, host, port, opts, ip) + local fd = tcp() + local addr = ip or self.ip or host or self.host + local pnum = port or self.port + local setts = opts or self.opts + local timeout = setts and setts.timeout or 2000 + + fd:settimeout(timeout) - if byte(qname, 1) == DOT_CHAR then - return nil, "bad name" + local ok, err = fd:connect(addr,pnum) + if not ok then + return false, err end - - local name = gsub(qname, "([^.]+)%.?", _encode_name) .. '\0' - - return { - ident_hi, ident_lo, flags, nqs, nan, nns, nar, - name, typ, class - } + + self.fd = fd + --self.id = _gen_id() + + return true end +local function _tcp_sock_write(self, data) + local query = concat(data,'') + local len = #query + local len_hi = char(rshift(len, 8)) + local len_lo = char(band(len, 0xff)) + + return self.fd:send({len_hi, len_lo, query}) +end -local function parse_section(answers, section, buf, start_pos, size, - should_skip) - local pos = start_pos - - for _ = 1, size do - -- print(format("ans %d: qtype:%d qclass:%d", i, qtype, qclass)) - local ans = {} +local function _tcp_sock_read(self) + local buf, err = self.fd:receive(2) + local len_hi = byte(buf, 1) + local len_lo = byte(buf, 2) + local len = lshift(len_hi, 8) + len_lo + + return self.fd:receive(len) +end - if not should_skip then - insert(answers, ans) - end +local _tcp_stream_mt = { + open = _tcp_open, + read = _tcp_sock_read, + write = _tcp_sock_write, + close = _sock_close, + settimeout = _sock_settimeout +} - ans.section = section +-------------[[ encrypted ssl/tls tcp stream ]]------------ - local name - name, pos = _decode_name(buf, pos) - if not name then - return nil, pos +local function _enc_tcp_open(self, host, port, opts, ip) + local addr = ip or host + local ok, err = _tcp_open(self,addr,port,opts,ip) + if not ok then + return false, err + end + + if self.fd:getreusedtimes() == 0 then + local session, err = self.fd:sslhandshake(nil,host) + if not session then + return false, err end + end + + return true +end - ans.name = name - - -- print("name: ", name) - - local type_hi = byte(buf, pos) - local type_lo = byte(buf, pos + 1) - local typ = lshift(type_hi, 8) + type_lo - - ans.type = typ - - -- print("type: ", typ) - - local class_hi = byte(buf, pos + 2) - local class_lo = byte(buf, pos + 3) - local class = lshift(class_hi, 8) + class_lo - - ans.class = class - - -- print("class: ", class) - - local byte_1, byte_2, byte_3, byte_4 = byte(buf, pos + 4, pos + 7) - - local ttl = lshift(byte_1, 24) + lshift(byte_2, 16) - + lshift(byte_3, 8) + byte_4 - - -- print("ttl: ", ttl) - - ans.ttl = ttl - - local len_hi = byte(buf, pos + 8) - local len_lo = byte(buf, pos + 9) - local len = lshift(len_hi, 8) + len_lo - - -- print("record len: ", len) - - pos = pos + 10 - - if typ == TYPE_A then - - if len ~= 4 then - return nil, "bad A record value length: " .. len - end - - local addr_bytes = { byte(buf, pos, pos + 3) } - local addr = concat(addr_bytes, ".") - -- print("ipv4 address: ", addr) - - ans.address = addr - - pos = pos + 4 - - elseif typ == TYPE_CNAME then - - local cname, p = _decode_name(buf, pos) - if not cname then - return nil, pos - end - - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) - end - - pos = p - - -- print("cname: ", cname) - - ans.cname = cname - - elseif typ == TYPE_AAAA then - - if len ~= 16 then - return nil, "bad AAAA record value length: " .. len - end - - local addr_bytes = { byte(buf, pos, pos + 15) } - local flds = {} - for i = 1, 16, 2 do - local a = addr_bytes[i] - local b = addr_bytes[i + 1] - if a == 0 then - insert(flds, format("%x", b)) - - else - insert(flds, format("%x%02x", a, b)) - end - end - - -- we do not compress the IPv6 addresses by default - -- due to performance considerations - - ans.address = concat(flds, ":") - - pos = pos + 16 - - elseif typ == TYPE_MX then - - -- print("len = ", len) - - if len < 3 then - return nil, "bad MX record value length: " .. len - end - - local pref_hi = byte(buf, pos) - local pref_lo = byte(buf, pos + 1) - - ans.preference = lshift(pref_hi, 8) + pref_lo - - local host, p = _decode_name(buf, pos + 2) - if not host then - return nil, pos - end - - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) - end - - ans.exchange = host - - pos = p - - elseif typ == TYPE_SRV then - if len < 7 then - return nil, "bad SRV record value length: " .. len - end - - local prio_hi = byte(buf, pos) - local prio_lo = byte(buf, pos + 1) - ans.priority = lshift(prio_hi, 8) + prio_lo - - local weight_hi = byte(buf, pos + 2) - local weight_lo = byte(buf, pos + 3) - ans.weight = lshift(weight_hi, 8) + weight_lo - - local port_hi = byte(buf, pos + 4) - local port_lo = byte(buf, pos + 5) - ans.port = lshift(port_hi, 8) + port_lo - - local name, p = _decode_name(buf, pos + 6) - if not name then - return nil, pos - end - - if p - pos ~= len then - return nil, format("bad srv record length: %d ~= %d", - p - pos, len) - end - - ans.target = name - - pos = p - - elseif typ == TYPE_NS then - - local name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end - - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) - end - - pos = p +local _enc_tcp_stream_mt = { + open = _enc_tcp_open, + read = _sock_read, + write = _sock_write, + close = _sock_close, + settimeout = _sock_settimeout +} - -- print("name: ", name) +-----------------[[ streams ]]--------------------------- - ans.nsdname = name +local function _new_stream_int(class,mt) + return setmetatable({ + host = class.host, + port = class.port, + ip = class.ip + },{ __index = mt}) +end - elseif typ == TYPE_TXT or typ == TYPE_SPF then - local key = (typ == TYPE_TXT) and "txt" or "spf" +local function _new_udp_stream(class) + return _new_stream_int(class,_udp_stream_mt) +end - local slen = byte(buf, pos) - if slen + 1 > len then - -- truncate the over-run TXT record data - slen = len - end - -- print("slen: ", len) - - local val = sub(buf, pos + 1, pos + slen) - local last = pos + len - pos = pos + slen + 1 - - if pos < last then - -- more strings to be processed - -- this code path is usually cold, so we do not - -- merge the following loop on this code path - -- with the processing logic above. - - val = {val} - local idx = 2 - repeat - local slen = byte(buf, pos) - if pos + slen + 1 > last then - -- truncate the over-run TXT record data - slen = last - pos - 1 - end +local function _new_tcp_stream(class) + return _new_stream_int(class,_tcp_stream_mt) +end - val[idx] = sub(buf, pos + 1, pos + slen) - idx = idx + 1 - pos = pos + slen + 1 - until pos >= last - end +local function _new_enc_stream(class) + return _new_stream_int(class,_enc_tcp_stream_mt) +end - ans[key] = val - elseif typ == TYPE_PTR then +local _udp_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, --(buf, id, opts), + stream = _new_udp_stream +} - local name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end +local _tcp_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, --(buf, id, opts), + stream = _new_tcp_stream +} - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) - end +local _udp_tcp_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, --(buf, id, opts), + stream = _new_udp_stream +} - pos = p +local _dot_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, -- (buf, id, opts), + stream = _new_enc_stream +} - -- print("name: ", name) +local _doh_wire_get_pimpl = { + build = _build_get_wire_request, + parse = _parse_wire_response, + stream = _new_enc_stream +} - ans.ptrdname = name +local _doh_json_get_pimpl = { + build = _build_get_json_request, + parse = _parse_json_response, + stream = _new_enc_stream +} - elseif typ == TYPE_SOA then - local name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end - ans.mname = name +local _doh_wire_post_pimpl = { + build = _build_post_wire_request, + parse = _parse_doh_wire_response, + stream = _new_enc_stream +} - pos = p - name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end - ans.rname = name +local doh_json_post_pimpl = { + build = _build_post_json_request, + parse = _parse_doh_json_response, + stream = _new_enc_stream +} - for _, field in ipairs(soa_int32_fields) do - local byte_1, byte_2, byte_3, byte_4 = byte(buf, p, p + 3) - ans[field] = lshift(byte_1, 24) + lshift(byte_2, 16) - + lshift(byte_3, 8) + byte_4 - p = p + 4 - end +---------------------------[[ server parsers ]]------------------------------- - pos = p +local function _udp_tcp_server_parser_int(server, opts, pimpl, mode) + local host, port + + if type(server) == 'table' then + host = server[1] + port = server[2] or 53 + else + host = server + port = 53 + end - else - -- for unknown types, just forward the raw value + return setmetatable({ + host = host, + port = port, + mode = mode + }, { __index = pimpl }) +end - ans.rdata = sub(buf, pos, pos + len - 1) - pos = pos + len - end - end - return pos +local function _udp_server_parser(server, opts) + return _udp_tcp_server_parser_int(server,opts,_udp_pimpl, MODE_UDP) end -local function parse_response(buf, id, opts) - local n = #buf - if n < 12 then - return nil, 'truncated'; - end +local function _tcp_server_parser(server, opts) + return _udp_tcp_server_parser_int(server,opts,_tcp_pimpl, MODE_TCP) +end - -- header layout: ident flags nqs nan nns nar - local ident_hi = byte(buf, 1) - local ident_lo = byte(buf, 2) - local ans_id = lshift(ident_hi, 8) + ident_lo +local function _udp_tcp_server_parser(server, opts) + return _udp_tcp_server_parser_int(server,opts,_udp_pimpl, MODE_UDP_TCP) +end - -- print("id: ", id, ", ans id: ", ans_id) - if ans_id ~= id then - -- identifier mismatch and throw it away - log(DEBUG, "id mismatch in the DNS reply: ", ans_id, " ~= ", id) - return nil, "id mismatch" +local function _dot_server_parser(server, opts) + local res, err = _tcp_servers_parser(server,opts) + if not res then + return nil, err end + + return setmetatable({ + host = host, + port = port, + mode = MODE_DOT + }, { __index = _dot_pimpl }) +end - local flags_hi = byte(buf, 3) - local flags_lo = byte(buf, 4) - local flags = lshift(flags_hi, 8) + flags_lo - - -- print(format("flags: 0x%x", flags)) - if band(flags, 0x8000) == 0 then - return nil, format("bad QR flag in the DNS response") +local function _doh_server_parser(server, opts) + local method = (type(server) == 'table') and server.method or 'GET' + method = DOH_METHOD[method] + if not method then + return false, "invalid DoH mode specified" + end + + local res, err = _tcp_server_parser(server, opts) + if not res then + return false, err end - - if band(flags, 0x200) ~= 0 then - return nil, "truncated" + + local url + local method + local ct + local ac + + if type(server) == 'table' then + url = server[1] or server.url + method = server[2] or server.method or ngx_get + ct = server[3] or server.ct or 'application/dns-message' + ac = server[4] or server.ac or 'application/dns-message' + else + url = server + method = ngx_get + ct = 'application/dns-message' + ac = 'application/dns-message' end - - local code = band(flags, 0xf) - - -- print(format("code: %d", code)) - - local nqs_hi = byte(buf, 5) - local nqs_lo = byte(buf, 6) - local nqs = lshift(nqs_hi, 8) + nqs_lo - - -- print("nqs: ", nqs) - - if nqs ~= 1 then - return nil, format("bad number of questions in DNS response: %d", nqs) + + local captures, err = re_match(url,"^((https?)(://))?([A-Za-z0-9\\.-]+)(:[1-9][0-9]*)?(/.+)$") + if not captures then + return false, err + end + + local host = captures[4] + local ssl = (captures[1] == 'https://') and true or false + local port + + if captures[5] then + port = tonumber(sub(captures[5],2)) + elseif not ssl then + port = 80 + else + port = 443 end - - local nan_hi = byte(buf, 7) - local nan_lo = byte(buf, 8) - local nan = lshift(nan_hi, 8) + nan_lo - - -- print("nan: ", nan) - - local nns_hi = byte(buf, 9) - local nns_lo = byte(buf, 10) - local nns = lshift(nns_hi, 8) + nns_lo - - local nar_hi = byte(buf, 11) - local nar_lo = byte(buf, 12) - local nar = lshift(nar_hi, 8) + nar_lo - - -- skip the question part - - local ans_qname, pos = _decode_name(buf, 13) - if not ans_qname then - return nil, pos + + if not port then + return false, "invalid port specified" end - - -- print("qname in reply: ", ans_qname) - - -- print("question: ", sub(buf, 13, pos)) - - if pos + 3 + nan * 12 > n then - -- print(format("%d > %d", pos + 3 + nan * 12, n)) - return nil, 'truncated'; + + local hoststr + if (ssl and port ~= 443) or (not ssl and port ~= 80) then + hoststr = host..":"..port + else + hoststr = host end + + local query = { + 'Host: '..hoststr..'\r\n', + 'User-Agent: '..agent..'\r\n', + 'Connection: keep-alive'..'\r\n', + 'Accept: '..ac..'\r\n' + } - -- question section layout: qname qtype(2) qclass(2) - - --[[ - local type_hi = byte(buf, pos) - local type_lo = byte(buf, pos + 1) - local ans_type = lshift(type_hi, 8) + type_lo - ]] - - -- print("ans qtype: ", ans_type) + insert(query,(method == ngx_post) and 'Content-Type: '..ct..'\r\n' or "\r\n") - local class_hi = byte(buf, pos + 2) - local class_lo = byte(buf, pos + 3) - local qclass = lshift(class_hi, 8) + class_lo + return setmetatable({ + host = host, + port = port, + url = captures[6], + ssl = ssl, + query = query, + mode = MODE_DOH + }, { __index = _doh_pimpl }) +end - -- print("ans qclass: ", qclass) +local _server_parser_tbl = { + { MODE_UDP, _udp_server_parser }, + { MODE_TCP, _tcp_server_parser }, + { MODE_UDP_TCP, _udp_tcp_server_parser }, + { MODE_DOT, _dot_server_parser }, + { MODE_DOH, _doh_server_parser } +} - if qclass ~= 1 then - return nil, format("unknown query class %d in DNS response", qclass) - end +----------------------------[[ servers array ]]------------------------------------ - pos = pos + 4 +local function _servers_at(self, at) + return self.servers[at] +end - local answers = {} - if code ~= 0 then - answers.errcode = code - answers.errstr = resolver_errstrs[code] or "unknown" - end +local function _servers_size(self) + return #self.servers +end - local authority_section, additional_section - if opts then - authority_section = opts.authority_section - additional_section = opts.additional_section - if opts.qtype == TYPE_SOA then - authority_section = true +local function _servers_array_new(opts) + log(ngx.ERR,"NEW SERVER ARRAY") + + local servers = opts.nameservers + local pservers = {} + local n = #servers + local pn = #_server_parser_tbl + + for i = 1, n do + local server = servers[i] + local mode = (type(server) == 'table') and server.mode or MODE_UDP_TCP + local pserver_tbl, err + + for k = 1, pn do + local f_tbl = _server_parser_tbl[k] + + if f_tbl[1] == mode then + local pserver_tbl, err = f_tbl[2](server, opts) + + if not pserver_tbl then + return nil, "failed to create server at: "..k.."with error: "..err + end + + insert(pservers,pserver_tbl) + break + end end + end + + local servers_mt = { + size = _servers_size, + at = _servers_at + } + + local servers_tbl = { + current = no_random and 1 or rand(1, #servers), + servers = pservers + } + + return setmetatable(servers_tbl, { __index = servers_mt }) +end - local err +-------------------------------------------------------------------------------- - pos, err = parse_section(answers, SECTION_AN, buf, pos, nan) +local function _generic_query() - if not pos then - return nil, err - end +end - if not authority_section and not additional_section then - return answers +local function answers__to_string(self) + local ret ='' + for k,v in pairs(self) do + local typ = type(v) + if typ ~= 'function' then + if typ == 'table' then + ret = ret..'[\r\n' + ret = ret..answers__to_string(v) + ret = ret..']\r\n' + else + ret = ret..k..': '..v..'\r\n' + end + end + end + return ret end - pos, err = parse_section(answers, SECTION_NS, buf, pos, nns, - not authority_section) - - if not pos then - return nil, err - end +local answers_mt = { + __tostring = answers__to_string +} - if not additional_section then - return answers +--[[ +Perform DNS TCP query over connected socket +]] +local function _tcp_query(self, server, qname, no_recurse, opts) + if server == nil or qname == nil then + return nil, 'invalid arguments', nil end - - pos, err = parse_section(answers, SECTION_AR, buf, pos, nar) - - if not pos then - return nil, err + + local srv, err = _tcp_server_parser(server, opts) + if srv == nil then + return nil, err, nil end - - return answers -end - - -local function _gen_id(self) - local id = self._id -- for regression testing - if id then - return id + + local id = _gen_id() + local query, err = srv:build(qname,id,no_recurse, opts) + if query == nil then + return nil, err, nil end - return rand(0, 65535) -- two bytes -end - - -local function _tcp_query(self, query, id, opts) - local sock = self.tcp_sock - if not sock then - return nil, "not initialized" + + local stream, err = srv:stream() + if stream == nil then + return nil, err, nil end - log(DEBUG, "query the TCP server due to reply truncation") - - local server = _get_cur_server(self) - - local ok, err = sock:connect(server[1], server[2]) + local ok, err = stream:open() if not ok then - return nil, "failed to connect to TCP server " - .. concat(server, ":") .. ": " .. err + return nil, err, nil end - - query = concat(query, "") - local len = #query - - local len_hi = char(rshift(len, 8)) - local len_lo = char(band(len, 0xff)) - - local bytes, err = sock:send({len_hi, len_lo, query}) + + local bytes, err = stream:write(query) if not bytes then return nil, "failed to send query to TCP server " - .. concat(server, ":") .. ": " .. err + .. stream.host .. ":" .. stream.port .. ": " .. err, nil end - local buf, err = sock:receive(2) + local buf, err = stream:read() if not buf then return nil, "failed to receive the reply length field from TCP server " - .. concat(server, ":") .. ": " .. err + .. stream.host, ":" .. stream.port.. ": " .. err, {} + end + + local answers, err = srv:parse(buf,id) + if not answers then + return nil, err end + + return setmetatable(answers,answers_mt), nil, {} +end - len_hi = byte(buf, 1) - len_lo = byte(buf, 2) - len = lshift(len_hi, 8) + len_lo +local function _udp_query(self, server, qname, no_recurse, opts) + if server == nil or qname == nil then + return nil, 'invalid arguments', nil + end + + local srv, err = _udp_server_parser(server, opts) + if srv == nil then + return nil, err, nil + end + + local id = _gen_id() + local query, err = srv:build(qname,id,no_recurse, opts) + if query == nil then + return nil, err, nil + end + + local stream, err = srv:stream() + if stream == nil then + return nil, err, nil + end - -- print("tcp message len: ", len) + local ok, err = stream:open() + if not ok then + return nil, err, nil + end + + local bytes, err = stream:write(query) + if not bytes then + return nil, "failed to send query to UDP server " + .. stream.host .. ":" .. stream.port .. ": " .. err, nil + end - buf, err = sock:receive(len) + local buf, err = stream:read() if not buf then - return nil, "failed to receive the reply message body from TCP server " - .. concat(server, ":") .. ": " .. err + return nil, "failed to receive the reply UDP server " + .. stream.host, ":" .. stream.port.. ": " .. err, {} end - local answers, err = parse_response(buf, id, opts) + local answers, err = srv:parse(buf,id) if not answers then - return nil, "failed to parse the reply from the TCP server " - .. concat(server, ":") .. ": " .. err + return nil, err end - sock:close() - - return answers + return setmetatable(answers,answers_mt) end -function _M.tcp_query(self, qname, opts) - local socks = self.socks - if not socks then - return nil, "not initialized" +local function _dot_query(self, server, qname, no_recurse, opts) + if server == nil or qname == nil then + return nil, 'invalid arguments', nil + end + + local srv, err = _dot_server_parser(server, opts) + if srv == nil then + return nil, err, nil + end + + local id = _gen_id() + local query, err = srv:build(qname,id,no_recurse, opts) + if query == nil then + return nil, err, nil + end + + local stream, err = srv:stream() + if stream == nil then + return nil, err, nil end - pick_sock(self, socks) + local ok, err = stream:open() + if not ok then + return nil, err, nil + end + + local bytes, err = stream:write(query) + if not bytes then + return nil, "failed to send query to DoT server " + .. stream.host .. ":" .. stream.port .. ": " .. err, nil + end - local id = _gen_id(self) + local buf, err = stream:read() + if not buf then + return nil, "failed to receive the reply DoT server " + .. stream.host, ":" .. stream.port.. ": " .. err, {} + end - local query, err = _build_request(qname, id, self.no_recurse, opts) - if not query then + local answers, err = srv:parse(buf,id) + if not answers then return nil, err end - return _tcp_query(self, query, id, opts) + return setmetatable(answers,answers_mt) end -local function _http_connect(self,host) - local sock = self.tcp_sock - if not sock then - return nil, "not initialized" - end - +--[[ local function _http_connect(sock,host) local ok, err = sock:connect(host[1], host[2]) if not ok then return nil, "failed to connect to HTTP server " @@ -797,7 +753,7 @@ local function _http_connect(self,host) return sock end - +]]-- local function _http_status_receive(sock) local line, err, partial = sock:receive("*l") @@ -838,7 +794,7 @@ local function _http_header_receive(sock) if type(ret[key]) ~= "table" then ret[key] = { ret[key] } end - tbl_insert(ret[key], tostring(val)) + insert(ret[key], tostring(val)) else ret[key] = tostring(val) end @@ -851,31 +807,7 @@ end local function _http_header_send(sock, host, method, length, param) local hoststr - if (host[4] and host[2] ~= 443) or (not host[4] and host[2] ~= 80) then - hoststr = host[1]..":"..host[2] - else - hoststr = host[1] - end - - local query = { - true, - 'Host: '..hoststr, - 'User-Agent: '..agent, - 'Accept: application/dns-message', - 'Connection: keep-alive' - } - - if method == nil or method == ngx.HTTP_GET then - query[1] = 'GET '.. host[3]..param.. ' HTTP/1.1' - query = concat(query,"\r\n").."\r\n\r\n" - elseif method == ngx.HTTP_POST then - query[1] = 'POST '.. host[3] .. ' HTTP/1.1' - insert(query,'Content-Length: '..length) - insert(query,'Content-Type: application/dns-message') - query = concat(query, "\r\n").."\r\n\r\n" - else - return nil, "unsupported method" - end + -- HEADER local bytes, err = sock:send(query) @@ -904,15 +836,14 @@ local function _http_body_receive(sock, header) end -local function _http_query(self,host,opts) - local sock, err = _http_connect(self, host) - +local function _http_query(sock,host,opts) + + local sock, err = _http_connect(sock, host) if not sock then return nil, err end local bytes, err = _http_header_send(sock, host, opts.method, opts.body and #opts.body or 0, opts.param) - if not bytes then return nil, err end @@ -935,13 +866,11 @@ local function _http_query(self,host,opts) end local header, err = _http_header_receive(sock) - if not header then return nil, err end local data, err = _http_body_receive(sock, header) - if not data then return nil, err end @@ -955,50 +884,56 @@ local function _http_query(self,host,opts) } end -local function _doh_query(self, qname, opts, tries) - local retrans = self.retrans - if tries then - tries[1] = nil - end +local function _doh_query(qname, opts, tries, servers) + --local sock = self.tcp_sock + --if not sock then + -- return nil, "not initialized" + --end + local servers = self.servers - - if #servers == 0 then + if not servers:size() then return nil, "no servers available" end + local retrans = self.retrans + if tries then + tries[1] = nil + end + + local method = self.doh_method local err + --if method == ngx_post then + -- id = _gen_id(self) + -- opts = { + -- method = ngx_post, + -- body = table.concat(_build_request(qname, id, self.no_recurse, opts)) + -- } + --else + -- opts = { + -- method = ngx_get, + -- param = self.doh_encode and b64.encode_base64url(qname) or qname + -- } + --end + for i = 1, retrans do - local idx = i - - if idx > #servers then - idx = 1 - end - - local res - local id - - if self.doh_method == ngx.HTTP_GET then - res = _http_query(self,servers[idx], { method = ngx.HTTP_GET, param = b64.encode_base64url(qname) }) - else - id = _gen_id(self) - local bdata = table.concat(_build_request(qname, id, self.no_recurse, opts)) - res, err = _http_query(self,servers[idx],{ method = ngx.HTTP_POST, body = bdata }) - end - + local id, opts + local server = servers:pick() + + local res, err = _http_query(sock, server, opts) if not res then return nil, err, tries end - if res.status == 200 and res.body then + if res and res.status == 200 and res.body then local answers - if self.doh_method == ngx.HTTP_GET then + if method == ngx_get then local ident_hi = byte(res.body, 1) local ident_lo = byte(res.body, 2) id = lshift(ident_hi, 8) + ident_lo end - answers, err = parse_response(res.body, id, opts) + answers, err = _parse_response(res.body, id, opts) if answers then return answers, nil, tries end @@ -1006,7 +941,7 @@ local function _doh_query(self, qname, opts, tries) if err and err ~= "id mismatch" then break else - log(DEBUG,"doh query failed to parse response",err) + log(DEBUG,"DoH query failed to parse response",err) end end @@ -1019,14 +954,12 @@ local function _doh_query(self, qname, opts, tries) return nil, err, tries end - -local function _udp_tcp_query(self, qname, opts, tries) - local socks = self.socks - if not socks then +local function _udp_tcp_query(qname, opts, tries, servers) + if not servers then return nil, "not initialized" end - local id = _gen_id(self) + --local id = _gen_id(self) local query, err = _build_request(qname, id, self.no_recurse, opts) if not query then @@ -1044,12 +977,11 @@ local function _udp_tcp_query(self, qname, opts, tries) -- print("retrans: ", retrans) for i = 1, retrans do - local sock = pick_sock(self, socks) - - local ok - ok, err = sock:send(query) + local sock = servers:pick_sock() + + local ok, err = sock:send(query) if not ok then - local server = _get_cur_server(self) + local server = servers:current_server() err = "failed to send request to UDP server " .. concat(server, ":") .. ": " .. err @@ -1059,7 +991,7 @@ local function _udp_tcp_query(self, qname, opts, tries) for _ = 1, 128 do buf, err = sock:receive(4096) if err then - local server = _get_cur_server(self) + local server = servers:current_server() err = "failed to receive reply from UDP server " .. concat(server, ":") .. ": " .. err break @@ -1067,9 +999,9 @@ local function _udp_tcp_query(self, qname, opts, tries) if buf then local answers - answers, err = parse_response(buf, id, opts) + answers, err = _parse_response(buf, id, opts) if err == "truncated" then - answers, err = _tcp_query(self, query, id, opts) + answers, err = _tcp_query(sock, query, id, opts, servers) end if err and err ~= "id mismatch" then @@ -1094,12 +1026,14 @@ local function _udp_tcp_query(self, qname, opts, tries) end -function _M.compress_ipv6_addr(addr) +---------------------------[[ private functions ]]------------------------ + +local function _compress_ipv6_addr(addr) local addr = re_sub(addr, "^(0:)+|(:0)+$|:(0:)+", "::", "jo") if addr == "::0" then addr = "::" end - + return addr end @@ -1107,36 +1041,33 @@ end local function _expand_ipv6_addr(addr) if find(addr, "::", 1, true) then local ncol, addrlen = 8, #addr - + for i = 1, addrlen do if byte(addr, i) == COLON_CHAR then ncol = ncol - 1 end end - + if byte(addr, 1) == COLON_CHAR then addr = "0" .. addr end - + if byte(addr, -1) == COLON_CHAR then addr = addr .. "0" end - + addr = re_sub(addr, "::", ":" .. rep("0:", ncol), "jo") end - + return addr end -_M.expand_ipv6_addr = _expand_ipv6_addr - - -function _M.arpa_str(addr) +local function _arpa_str(addr) if find(addr, ":", 1, true) then addr = _expand_ipv6_addr(addr) local idx, hidx, addrlen = 1, 1, #addr - + for i = addrlen, 0, -1 do local s = byte(addr, i) if s == COLON_CHAR or not s then @@ -1151,148 +1082,141 @@ function _M.arpa_str(addr) hidx = hidx + 1 end end - + addr = char(unpack(arpa_tmpl)) else addr = re_sub(addr, [[(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})]], "$4.$3.$2.$1.in-addr.arpa", "ajo") end - + return addr end -function _M.reverse_query(self, addr) - return self.query(self, self.arpa_str(addr), - {qtype = self.TYPE_PTR}) -end - +------------------[[ public instance methods ]]--------------- -local function _new_doh(class,opts) - local method - if opts.doh_method == 'POST' then - method = ngx.HTTP_POST - elseif opts.doh_method == 'GET' then - method = ngx.HTTP_GET - else - return nil, nil, "invalid DoH mode specified" +local function _query(self, qname, opts, tries) + log(ngx.ERR,"QUERY") + + local servers = self.servers + if not servers then + return nil, "not initialized" end - - local servers = opts.nameservers - local n = #servers - - for i = 1, n do - local captures, err = re_match(servers[i],"^((https?)(://))?([A-Za-z0-9\\.-]+)(:[1-9][0-9]*)?(/.+)$") - - if not captures then - return nil, nil, err - end - - local host = captures[4] - local ssl = (captures[1] == 'https://') and true or false - local port - - if captures[5] then - port = tonumber(sub(captures[5],2)) - elseif not ssl then - port = 80 + + local retrans = self.retrans + -- print("retrans: ", retrans) + if tries then + tries[1] = nil + end + + local id = _gen_id(self) + local query, err = _build_wire_request(qname, id, self.no_recurse, opts) + if not query then + return nil, err + end + + -- local cjson = require "cjson" + -- print("query: ", cjson.encode(concat(query, ""))) + + + + + + for i = 1, retrans do + local sock = servers:pick() + + --[[ Abstract send ]] + local ok, err = sock:send(query) + if not ok then + local server = servers:current_server() + err = "failed to send request to UDP server " + .. concat(server, ":") .. ": " .. err + --[[ End Of Send ]] else - port = 443 + + local buf + for _ = 1, 128 do + --[[ Receive ]] + buf, err = sock:receive(4096) + if err then + local server = servers:current_server() + err = "failed to receive reply from UDP server " + .. concat(server, ":") .. ": " .. err + break + end + --[[ End Of Receive]] + + --[[ Parse ]] + if buf then + local answers + answers, err = _parse_response(buf, id, opts) + if err == "truncated" then + answers, err = _tcp_query(sock, query, id, opts, servers) + end + + if err and err ~= "id mismatch" then + break + end + + if answers then + return answers, nil, tries + end + end + --[[ End Of Parse ]] + -- only here in case of an "id mismatch" + end + --[[ END ]] end - - if not port then - return nil, nil, "invalid port specified" + + if tries then + tries[i] = err + tries[i + 1] = nil -- ensure termination for user supplied table end - - servers[i] = { host, port, captures[6], ssl} end - - _M.query = _doh_query - - return servers, method + + return nil, err, tries end -local function _new_tcp_udp(class,opts,timeout) - local servers = opts.nameservers - local n = #servers - local socks = {} - - for i = 1, n do - local server = servers[i] - local host, port, ssl - - if type(server) == 'table' then - host = server[1] - port = server[2] or 53 - else - host = server - port = 53 - servers[i] = {host, port} - end - - local sock, err = udp() - if not sock then - return nil, "failed to create udp socket: " .. err - end - - local ok, err = sock:setpeername(host, port) - if not ok then - return nil, "failed to set peer name: " .. err - end - - sock:settimeout(timeout) - insert(socks, sock) - end +local function _reverse_query(class, addr) + log(ngx.ERR,"REVERSE QUERY") + return query(class, arpa_str(addr), + {qtype = class.TYPE_PTR}) +end - _M.query = _udp_tcp_query - return servers,socks -end +local resolver_mt = { + udp_query = _udp_query, + udp_tcp_query = _udp_tcp_query, + tcp_query = _tcp_query, + dot_query = _dot_query, + doh_query = _doh_query, + query = _query, + reverse_query = _reverse_query +} +----------------------------------------------------------------------------- function _M.new(class, opts) if not opts then return nil, "no options table specified" end - local servers = opts.nameservers - if not servers or #servers == 0 then + local nameservers = opts.nameservers + if not nameservers or #nameservers == 0 then return nil, "no nameservers specified" end - local timeout = opts.timeout or 2000 -- default 2 sec - local servers, socks, err - local method - - if opts.doh then - servers, method, err = _new_doh(class,opts) - else - servers, socks, err = _new_tcp_udp(class,opts,timeout) - end - + local servers, err = _servers_array_new(opts) if not servers then - return nil, err + return nil, err end - - local tcp_sock, err = tcp() - if not tcp_sock then - return nil, "failed to create tcp socket: " .. err - end - - tcp_sock:settimeout(timeout) - return setmetatable( - { cur = opts.no_random and 1 or rand(1, #servers), - socks = socks, - tcp_sock = tcp_sock, - servers = servers, - retrans = opts.retrans or 5, - no_recurse = opts.no_recurse, - doh = opts.doh, - doh_method = method - }, mt) + return setmetatable({ + servers = servers, + retrans = opts.retrans or 5, + no_recurse = opts.no_recurse + }, { __index = resolver_mt }) end - return _M diff --git a/lib/resty/dns/wireformat.lua b/lib/resty/dns/wireformat.lua new file mode 100644 index 0000000..efcda00 --- /dev/null +++ b/lib/resty/dns/wireformat.lua @@ -0,0 +1,596 @@ +local bit = require "bit" +local band = bit.band +local rshift = bit.rshift +local lshift = bit.lshift +local insert = table.insert +local concat = table.concat +local byte = string.byte +local char= string.char +local byte = string.byte +local sub = string.sub +local gsub = string.gsub + +local log = ngx.log +local DEBUG = ngx.DEBUG + +local DOT_CHAR = byte(".") +local ZERO_CHAR = byte("0") + +local TYPE_A = 1 +local TYPE_NS = 2 +local TYPE_CNAME = 5 +local TYPE_SOA = 6 +local TYPE_PTR = 12 +local TYPE_MX = 15 +local TYPE_TXT = 16 +local TYPE_AAAA = 28 +local TYPE_SRV = 33 +local TYPE_SPF = 99 + +local CLASS_IN = 1 + +local SECTION_AN = 1 +local SECTION_NS = 2 +local SECTION_AR = 3 + +local soa_int32_fields = { "serial", "refresh", "retry", "expire", "minimum" } + +local resolver_errstrs = { + "format error", -- 1 + "server failure", -- 2 + "name error", -- 3 + "not implemented", -- 4 + "refused", -- 5 +} + +local function _encode_name(s) + return char(#s) .. s +end + + +local function _decode_name(buf, pos) + local labels = {} + local nptrs = 0 + local p = pos + while nptrs < 128 do + local fst = byte(buf, p) + + if not fst then + return nil, 'truncated'; + end + + -- print("fst at ", p, ": ", fst) + + if fst == 0 then + if nptrs == 0 then + pos = pos + 1 + end + break + end + + if band(fst, 0xc0) ~= 0 then + -- being a pointer + if nptrs == 0 then + pos = pos + 2 + end + + nptrs = nptrs + 1 + + local snd = byte(buf, p + 1) + if not snd then + return nil, 'truncated' + end + + p = lshift(band(fst, 0x3f), 8) + snd + 1 + + -- print("resolving ptr ", p, ": ", byte(buf, p)) + + else + -- being a label + local label = sub(buf, p + 1, p + fst) + insert(labels, label) + + -- print("resolved label ", label) + + p = p + fst + 1 + + if nptrs == 0 then + pos = p + end + end + end + + return concat(labels, "."), pos +end + + +local function _parse_wire_section(answers, section, buf, start_pos, size, + should_skip) + local pos = start_pos + + for _ = 1, size do + -- print(format("ans %d: qtype:%d qclass:%d", i, qtype, qclass)) + local ans = {} + + if not should_skip then + insert(answers, ans) + end + + ans.section = section + + local name + name, pos = _decode_name(buf, pos) + if not name then + return nil, pos + end + + ans.name = name + + -- print("name: ", name) + + local type_hi = byte(buf, pos) + local type_lo = byte(buf, pos + 1) + local typ = lshift(type_hi, 8) + type_lo + + ans.type = typ + + -- print("type: ", typ) + + local class_hi = byte(buf, pos + 2) + local class_lo = byte(buf, pos + 3) + local class = lshift(class_hi, 8) + class_lo + + ans.class = class + + -- print("class: ", class) + + local byte_1, byte_2, byte_3, byte_4 = byte(buf, pos + 4, pos + 7) + + local ttl = lshift(byte_1, 24) + lshift(byte_2, 16) + + lshift(byte_3, 8) + byte_4 + + -- print("ttl: ", ttl) + + ans.ttl = ttl + + local len_hi = byte(buf, pos + 8) + local len_lo = byte(buf, pos + 9) + local len = lshift(len_hi, 8) + len_lo + + -- print("record len: ", len) + + pos = pos + 10 + + if typ == TYPE_A then + + if len ~= 4 then + return nil, "bad A record value length: " .. len + end + + local addr_bytes = { byte(buf, pos, pos + 3) } + local addr = concat(addr_bytes, ".") + -- print("ipv4 address: ", addr) + + ans.address = addr + + pos = pos + 4 + + elseif typ == TYPE_CNAME then + + local cname, p = _decode_name(buf, pos) + if not cname then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("cname: ", cname) + + ans.cname = cname + + elseif typ == TYPE_AAAA then + + if len ~= 16 then + return nil, "bad AAAA record value length: " .. len + end + + local addr_bytes = { byte(buf, pos, pos + 15) } + local flds = {} + for i = 1, 16, 2 do + local a = addr_bytes[i] + local b = addr_bytes[i + 1] + if a == 0 then + insert(flds, format("%x", b)) + + else + insert(flds, format("%x%02x", a, b)) + end + end + + -- we do not compress the IPv6 addresses by default + -- due to performance considerations + + ans.address = concat(flds, ":") + + pos = pos + 16 + + elseif typ == TYPE_MX then + + -- print("len = ", len) + + if len < 3 then + return nil, "bad MX record value length: " .. len + end + + local pref_hi = byte(buf, pos) + local pref_lo = byte(buf, pos + 1) + + ans.preference = lshift(pref_hi, 8) + pref_lo + + local host, p = _decode_name(buf, pos + 2) + if not host then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + ans.exchange = host + + pos = p + + elseif typ == TYPE_SRV then + if len < 7 then + return nil, "bad SRV record value length: " .. len + end + + local prio_hi = byte(buf, pos) + local prio_lo = byte(buf, pos + 1) + ans.priority = lshift(prio_hi, 8) + prio_lo + + local weight_hi = byte(buf, pos + 2) + local weight_lo = byte(buf, pos + 3) + ans.weight = lshift(weight_hi, 8) + weight_lo + + local port_hi = byte(buf, pos + 4) + local port_lo = byte(buf, pos + 5) + ans.port = lshift(port_hi, 8) + port_lo + + local name, p = _decode_name(buf, pos + 6) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad srv record length: %d ~= %d", + p - pos, len) + end + + ans.target = name + + pos = p + + elseif typ == TYPE_NS then + + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("name: ", name) + + ans.nsdname = name + + elseif typ == TYPE_TXT or typ == TYPE_SPF then + + local key = (typ == TYPE_TXT) and "txt" or "spf" + + local slen = byte(buf, pos) + if slen + 1 > len then + -- truncate the over-run TXT record data + slen = len + end + + -- print("slen: ", len) + + local val = sub(buf, pos + 1, pos + slen) + local last = pos + len + pos = pos + slen + 1 + + if pos < last then + -- more strings to be processed + -- this code path is usually cold, so we do not + -- merge the following loop on this code path + -- with the processing logic above. + + val = {val} + local idx = 2 + repeat + local slen = byte(buf, pos) + if pos + slen + 1 > last then + -- truncate the over-run TXT record data + slen = last - pos - 1 + end + + val[idx] = sub(buf, pos + 1, pos + slen) + idx = idx + 1 + pos = pos + slen + 1 + + until pos >= last + end + + ans[key] = val + + elseif typ == TYPE_PTR then + + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("name: ", name) + + ans.ptrdname = name + + elseif typ == TYPE_SOA then + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + ans.mname = name + + pos = p + name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + ans.rname = name + + for _, field in ipairs(soa_int32_fields) do + local byte_1, byte_2, byte_3, byte_4 = byte(buf, p, p + 3) + ans[field] = lshift(byte_1, 24) + lshift(byte_2, 16) + + lshift(byte_3, 8) + byte_4 + p = p + 4 + end + + pos = p + + else + -- for unknown types, just forward the raw value + + ans.rdata = sub(buf, pos, pos + len - 1) + pos = pos + len + end + end + + return pos +end + + +local function _parse_wire_response(buf, id, opts) + local n = #buf + if n < 12 then + return nil, 'truncated' + end + + -- header layout: ident flags nqs nan nns nar + + local ident_hi = byte(buf, 1) + local ident_lo = byte(buf, 2) + local ans_id = lshift(ident_hi, 8) + ident_lo + + -- print("id: ", id, ", ans id: ", ans_id) + + if ans_id ~= id then + -- identifier mismatch and throw it away + log(DEBUG, "id mismatch in the DNS reply: ", ans_id, " ~= ", id) + return nil, "id mismatch" + end + + local flags_hi = byte(buf, 3) + local flags_lo = byte(buf, 4) + local flags = lshift(flags_hi, 8) + flags_lo + + -- print(format("flags: 0x%x", flags)) + + if band(flags, 0x8000) == 0 then + return nil, format("bad QR flag in the DNS response") + end + + if band(flags, 0x200) ~= 0 then + return nil, "truncated" + end + + local code = band(flags, 0xf) + + -- print(format("code: %d", code)) + + local nqs_hi = byte(buf, 5) + local nqs_lo = byte(buf, 6) + local nqs = lshift(nqs_hi, 8) + nqs_lo + + -- print("nqs: ", nqs) + + if nqs ~= 1 then + return nil, format("bad number of questions in DNS response: %d", nqs) + end + + local nan_hi = byte(buf, 7) + local nan_lo = byte(buf, 8) + local nan = lshift(nan_hi, 8) + nan_lo + + -- print("nan: ", nan) + + local nns_hi = byte(buf, 9) + local nns_lo = byte(buf, 10) + local nns = lshift(nns_hi, 8) + nns_lo + + local nar_hi = byte(buf, 11) + local nar_lo = byte(buf, 12) + local nar = lshift(nar_hi, 8) + nar_lo + + -- skip the question part + + local ans_qname, pos = _decode_name(buf, 13) + if not ans_qname then + return nil, pos + end + + -- print("qname in reply: ", ans_qname) + + -- print("question: ", sub(buf, 13, pos)) + + if pos + 3 + nan * 12 > n then + -- print(format("%d > %d", pos + 3 + nan * 12, n)) + return nil, 'truncated' + end + + -- question section layout: qname qtype(2) qclass(2) + + --[[ + local type_hi = byte(buf, pos) + local type_lo = byte(buf, pos + 1) + local ans_type = lshift(type_hi, 8) + type_lo + ]] + + -- print("ans qtype: ", ans_type) + + local class_hi = byte(buf, pos + 2) + local class_lo = byte(buf, pos + 3) + local qclass = lshift(class_hi, 8) + class_lo + + -- print("ans qclass: ", qclass) + + if qclass ~= 1 then + return nil, format("unknown query class %d in DNS response", qclass) + end + + pos = pos + 4 + + local answers = {} + + if code ~= 0 then + answers.errcode = code + answers.errstr = resolver_errstrs[code] or "unknown" + end + + local authority_section, additional_section + + if opts then + authority_section = opts.authority_section + additional_section = opts.additional_section + if opts.qtype == TYPE_SOA then + authority_section = true + end + end + + local err + + pos, err = _parse_wire_section(answers, SECTION_AN, buf, pos, nan) + + if not pos then + return nil, err + end + + if not authority_section and not additional_section then + return answers + end + + pos, err = _parse_wire_section(answers, SECTION_NS, buf, pos, nns, + not authority_section) + + if not pos then + return nil, err + end + + if not additional_section then + return answers + end + + pos, err = _parse_wire_section(answers, SECTION_AR, buf, pos, nar) + + if not pos then + return nil, err + end + + return answers +end + +local function _build_wire_request(qname, id, no_recurse, opts) + local qtype = opts and opts.qtype or 1 + local ident_hi = char(rshift(id, 8)) + local ident_lo = char(band(id, 0xff)) + + local flags + if no_recurse then + -- print("found no recurse") + flags = "\0\0" + else + flags = "\1\0" + end + + local nqs = "\0\1" + local nan = "\0\0" + local nns = "\0\0" + local nar = "\0\0" + local typ = char(rshift(qtype, 8), band(qtype, 0xff)) + local class = "\0\1" -- the Internet class + + if byte(qname, 1) == DOT_CHAR then + return nil, "bad name" + end + + local name = gsub(qname, "([^.]+)%.?", _encode_name) .. '\0' + + return { + ident_hi, ident_lo, flags, nqs, nan, nns, nar, + name, typ, class + } +end + + +return { + TYPE = { + A = TYPE_A, + NS = TYPE_NS, + CNAME = TYPE_CNAME, + SOA = TYPE_SOA, + PTR = TYPE_PTR, + MX = TYPE_MX, + TXT = TYPE_TXT, + AAAA = TYPE_AAAA, + SRV = TYPE_SRV, + SPF = TYPE_SPF + }, + CLASS = { + IN = CLASS_IN + }, + SECTION = { + AN = SECTION_AN, + NS = SECTION_NS, + AR = SECTION_AR + }, + build_request = _build_wire_request, + parse_response = _parse_wire_response +} diff --git a/t/mock.t b/t/mock.t index 43c186f..92a5d05 100644 --- a/t/mock.t +++ b/t/mock.t @@ -1994,3 +1994,97 @@ failed to query: 3: failed to receive reply from UDP server 127.0.0.1:20002: connection refused --- error_log Connection refused + + + +=== TEST 40: single answer DoH GET request, good A answer +--- http_config eval: $::HttpConfig +--- config + location /t { + content_by_lua ' + local resolver = require "resty.dns.resolver" + + local r, err = resolver:new{ + nameservers = { "https://cloudflare-dns.com/dns-query?name=" }, + doh = true, + doh_method = 'GET' + } + if not r then + ngx.say("failed to instantiate resolver: ", err) + return + end + + r._id = 125 + + local ans, err = r:query("www.google.com", { qtype = r.TYPE_A }) + if not ans then + ngx.say("failed to query: ", err) + return + end + + local ljson = require "ljson" + ngx.say("records: ", ljson.encode(ans)) + '; + } +--- doh_reply dns +{ + id => 125, + opcode => 0, + qname => 'www.google.com', + answer => [{ name => "www.google.com", ipv4 => "127.0.0.1", ttl => 123456 }], +} +--- request +GET /t +--- doh_query eval +"\x{00}}\x{01}\x{00}\x{00}\x{01}\x{00}\x{00}\x{00}\x{00}\x{00}\x{00}\x{03}www\x{06}google\x{03}com\x{00}\x{00}\x{01}\x{00}\x{01}" +--- response_body +records: [{"address":"127.0.0.1","class":1,"name":"www.google.com","section":1,"ttl":123456,"type":1}] +--- no_error_log +[error] + + + +=== TEST 41: single answer DoH POST reply, good A answer +--- http_config eval: $::HttpConfig +--- config + location /t { + content_by_lua ' + local resolver = require "resty.dns.resolver" + + local r, err = resolver:new{ + nameservers = { "https://cloudflare-dns.com/dns-query" }, + doh = true, + doh_method = 'POST' + } + if not r then + ngx.say("failed to instantiate resolver: ", err) + return + end + + r._id = 125 + + local ans, err = r:query("www.google.com", { qtype = r.TYPE_A }) + if not ans then + ngx.say("failed to query: ", err) + return + end + + local ljson = require "ljson" + ngx.say("records: ", ljson.encode(ans)) + '; + } +--- doh_reply dns +{ + id => 125, + opcode => 0, + qname => 'www.google.com', + answer => [{ name => "www.google.com", ipv4 => "127.0.0.1", ttl => 123456 }], +} +--- request +GET /t +--- doh_query eval +"\x{00}}\x{01}\x{00}\x{00}\x{01}\x{00}\x{00}\x{00}\x{00}\x{00}\x{00}\x{03}www\x{06}google\x{03}com\x{00}\x{00}\x{01}\x{00}\x{01}" +--- response_body +records: [{"address":"127.0.0.1","class":1,"name":"www.google.com","section":1,"ttl":123456,"type":1}] +--- no_error_log +[error]