From 71f64a169fa36172a7af4fa6e9b323df0d27159a Mon Sep 17 00:00:00 2001 From: JasonCheng1 Date: Wed, 5 May 2021 17:13:09 -0400 Subject: [PATCH 1/3] this week's commit got it working pretty --- my_server.py | 188 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 130 insertions(+), 58 deletions(-) diff --git a/my_server.py b/my_server.py index 885164c..4d75f2c 100644 --- a/my_server.py +++ b/my_server.py @@ -5,13 +5,13 @@ import struct import argparse from sys import argv -from time import sleep from helper_funcs import DNSQuery from datetime import datetime, timedelta from collections import defaultdict, OrderedDict -import copy import random import ipaddress +import itertools +import select MAX_LEVEL = 10 # max number of iterative queries we make before we stop CACHE_SIZE = 100000 @@ -20,9 +20,12 @@ A_TYPE = 1 NS_TYPE = 2 CNAME_TYPE = 5 +SOA_TYPE = 6 NOERROR_RCODE = 0 +SERVFAIL_RCODE = 2 NAMEERROR_RCODE = 3 +REFUSED_RCODE = 5 class LRUCache: @@ -55,6 +58,7 @@ def __init__(self, port): def check_Cache(self, key, now): a = self.cache.get(key) + if a: a = [record for record in a if record["expire_time"] > now] # Get rid of expired records self.cache.put(key, a if a else None) @@ -74,13 +78,28 @@ def query_then_cache(self, name_server, q): ### Query sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("", 8000)) + sock.bind(("", 12000)) name_server = str(ipaddress.ip_address(name_server)) + sock.connect((name_server, 53)) q.header["RD"] = 0 # we do not want recursive query + q.header["QR"] = 0 + q.header["ANCOUNT"] = 0 + q.header["NSCOUNT"] = 0 + q.header["NSCOUNT"] = 0 + q.header["ARCOUNT"] = 0 + q.answers = [] sock.sendall(q.to_bytes()) - answer = sock.recv(512) + if not __debug__: + print("Q: ", q) + + read, _, _ = select.select([sock], [], [], 10) # Timeout after 10 sec + if read: + answer, _ = sock.recvfrom(512) + else: + return q + sock.close() a = DNSQuery(answer) @@ -90,7 +109,7 @@ def query_then_cache(self, name_server, q): new_rr = defaultdict(list) for record in a.answers: key = ( - record["NAME"].decode("utf-8"), + record["NAME"].decode("utf-8").lower(), record["TYPE"], record["CLASS"], ) @@ -101,9 +120,12 @@ def query_then_cache(self, name_server, q): new_rr[key].append(val) for key, val in new_rr.items(): self.cache.put(key, val) - print("%%") - # print([[str(num) for num in record["RDATA"]] for record in a.answers]) - print(a) + + # TODO Support Negative Caching + # else a.header["RCODE"] == NXDOMAIN ... + + if not __debug__: + print("A: ", a) return a @@ -111,16 +133,14 @@ def get_dns_response(self, query): # input: A query and any state in self # returns: the correct response to the query obtained by asking DNS name servers # Your code goes here, when you change any 'self' variables make sure to use a lock - print("***") - print(query) - print("&&&") q = DNSQuery(query) - print(q) + if not __debug__: + print("Query: ", q) ### Reject EDNS Reference: https://tools.ietf.org/html/rfc6891#section-6 if q.header["ARCOUNT"] and any(rec["TYPE"] == 41 for rec in q.answers): q.header["QR"] = 1 - q.header["RCODE"] = 4 + q.header["RCODE"] = REFUSED_RCODE q.header["ANCOUNT"] = 0 q.header["NSCOUNT"] = 0 q.header["ARCOUNT"] = 0 @@ -128,18 +148,31 @@ def get_dns_response(self, query): return q.to_bytes() sname, stype, sclass = ( - q.question["NAME"].decode("utf-8"), + q.question["NAME"].decode("utf-8").lower(), q.question["QTYPE"], q.question["QCLASS"], ) - - return self.recursive_lookup(q, sname, stype, sclass).to_bytes() - - def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=120): + # TODO Handle ANY Query + """ + if stype == ANY_TYPE: + for record_type in [A, AAAA, ....]: + check if (sname, record_type, sclass) in cache: + append rec to q.answer + q.header["ANCOUNT] = len(q.answer) + return q + """ + + return self.recursive_lookup( + q, sname, stype, sclass, datetime.now() + ).to_bytes() # NOTE datetime.now() was not updating until I added in as a parameter + + def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60): if (datetime.now() - now) > timedelta(seconds=limit): # look up took longer than 100s - print("Query Timed Out") + if not __debug__: + print("Query Timed Out [1]", q, datetime.now(), now) q.header["QR"] = 1 - q.header["RCODE"] = 2 + q.header["RA"] = 1 + q.header["RCODE"] = SERVFAIL_RCODE q.header["ANCOUNT"] = 0 q.header["NSCOUNT"] = 0 q.header["ARCOUNT"] = 0 @@ -153,49 +186,80 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=12 q.header["QR"] = 1 q.header["ANCOUNT"] = len(a) q.header["AA"] = 0 # Because result is from a cache + q.header["RA"] = 1 for record in a: - record["resp"]["TTL"] = int((record["expire_time"] - datetime.now()).total_seconds()) + record["resp"]["TTL"] = int( + (record["expire_time"] - datetime.now()).total_seconds() + ) # UPDATE the TTL of records q.answers = [rec["resp"] for rec in a] return q while True: + if (datetime.now() - now) > timedelta(seconds=limit): # look up took longer than 100s + if not __debug__: + print("Query Timed Out [2]", q, datetime.now(), now) + q.header["QR"] = 1 + q.header["RA"] = 1 + q.header["RCODE"] = SERVFAIL_RCODE + q.header["ANCOUNT"] = 0 + q.header["NSCOUNT"] = 0 + q.header["ARCOUNT"] = 0 + q.answers = [] + return q + ### STEP 2 ### sbelt = [ - "198.41.0.4", - "199.9.14.201", - "195.129.12.83", - ] # A and B root server - slist = [] + "198.41.0.4", # root A + "199.9.14.201", # root B + "195.129.12.83", # dutch + # "1.1.1.1", # cloudflare + ] + mat_slist = defaultdict(list) split_sname = sname.split(".") for i in range(len(split_sname)): + slist = [] reduced_sname = ".".join(split_sname[i:]) - reduced_key = (reduced_sname, NS_TYPE, sclass) - ns_records = self.check_Cache(reduced_key, now) - if ns_records: - for ns in ns_records: - ns_name = ns["RDATA"][0] - reduced_key = (ns_name.decode("utf-8"), A_TYPE, sclass) - if a := self.check_Cache(reduced_key, now): # append the ip of name servers - random_a = random.choice(a) - slist.extend(random_a["RDATA"]) - if ns_records and not slist: - # TODO Kick off parallel process to look for the ip addresses of NS server - # parallel_thread = threading.Thread(target=self.recursive_lookup, args=(q, *reduced_key, now, 30)) - # parallel_thread.start() - - ### SINGLE THREAD For now - original_name = q.question["NAME"] - for ns in ns_records: - ns_name = ns["RDATA"][0] - q.question["NAME"] = ns_name - reduced_key = (ns_name.decode("utf-8"), A_TYPE, sclass) - _ = self.recursive_lookup(q, *reduced_key, now, 60) - a = self.check_Cache(reduced_key, now) - if a: - slist.extend(random.choice(a)["RDATA"]) - q.question["NAME"] = original_name - - slist.extend(sbelt) + if not reduced_sname: + continue + reduced_key = (reduced_sname, A_TYPE, sclass) + if a := self.check_Cache(reduced_key, now): # append the ip of name servers + slist.extend([rec["RDATA"][0] for rec in a]) + else: + reduced_key = (reduced_sname, NS_TYPE, sclass) + ns_records = self.check_Cache(reduced_key, now) + if ns_records: + for ns in ns_records: + ns_name = ns["RDATA"][0] + reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass) + if a := self.check_Cache(reduced_key, now): + slist.extend([rec["RDATA"][0] for rec in a]) + if not slist: + # TODO Kick off parallel process to look for the ip addresses of NS server + # parallel_thread = threading.Thread(target=self.recursive_lookup, args=(q, *reduced_key, now, 30)) + # parallel_thread.start() + + ### SINGLE THREAD for now + original_name = q.question["NAME"] + for ns in ns_records: + ns_name = ns["RDATA"][0] + q.question["NAME"] = ns_name + reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass) + _ = self.recursive_lookup(q, *reduced_key, now) + if a := self.check_Cache(reduced_key, now): + slist.extend([rec["RDATA"][0] for rec in a]) + + q.question["NAME"] = original_name + + mat_slist[reduced_sname] = slist + + mat_slist["."] = sbelt + if not __debug__: + for k, value in mat_slist.items(): + if k == ".": + continue + print(k, ":", [socket.inet_ntoa(v) for v in value]) + + slist = list(itertools.chain.from_iterable(mat_slist.values())) # flatten mat_slist ### STEP 3 ### for ns in slist: @@ -205,13 +269,21 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=12 gotAns = self.check_Cache(key, now) ### STEP 4.1 ### - if (gotAns and a.header["RCODE"] == NOERROR_RCODE) or a.header[ - "RCODE" - ] == NAMEERROR_RCODE: # We found it or there's an ans + if ( + (gotAns and a.header["RCODE"] == NOERROR_RCODE) # we found an Answer + or a.header["RCODE"] == NAMEERROR_RCODE # there was a name error + or ( + a.answers and a.header["AA"] == 1 and [rec for rec in a.answers if rec["TYPE"] == SOA_TYPE] + ) # SOA response + or ( + q.header["ANCOUNT"] and [rec for rec in a.answers if rec["TYPE"] == A_TYPE] + ) # found an ans but name does not match quetion name + ): + a.header["RA"] = 1 return a ### STEP 4.2 ### - elif a.answers and (ns_records := [rec for rec in a.answers if rec["TYPE"] == NS_TYPE]): # NSNAME + elif a.answers and ([rec for rec in a.answers if rec["TYPE"] == NS_TYPE]): # NSNAME break ### STEP 4.3 ### @@ -219,7 +291,7 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=12 new_sname = random.choice(cname_rec)["RDATA"][0] original_name = q.question["NAME"] q.question["NAME"] = new_sname - resp = self.recursive_lookup(q, new_sname.decode("utf-8"), stype, sclass, now) + resp = self.recursive_lookup(q, new_sname.decode("utf-8").lower(), stype, sclass, now) resp.question["NAME"] = original_name return resp From 2364d7be160a55722eba7c8a4bd4df32a5077fd4 Mon Sep 17 00:00:00 2001 From: JasonCheng1 Date: Fri, 7 May 2021 04:38:07 -0400 Subject: [PATCH 2/3] late night commit --- my_server.py | 244 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 142 insertions(+), 102 deletions(-) diff --git a/my_server.py b/my_server.py index 4d75f2c..27993ca 100644 --- a/my_server.py +++ b/my_server.py @@ -10,6 +10,7 @@ from collections import defaultdict, OrderedDict import random import ipaddress +import copy import itertools import select @@ -21,7 +22,16 @@ NS_TYPE = 2 CNAME_TYPE = 5 SOA_TYPE = 6 - +WKS_TYPE = 11 +PTR_TYPE = 12 +HINFO_TYPE = 13 +MINFO_TYPE = 14 +MX_TYPE = 15 +TXT_TYPE = 16 +ANY_TYPE = 255 +# ANY_TYPE = [A_TYPE, NS_TYPE, CNAME_TYPE, SOA_TYPE, WKS_TYPE, PTR_TYPE, HINFO_TYPE, MINFO_TYPE, MX_TYPE, TXT_TYPE] + +# RCODE Values NOERROR_RCODE = 0 SERVFAIL_RCODE = 2 NAMEERROR_RCODE = 3 @@ -57,6 +67,9 @@ def __init__(self, port): self.cache = LRUCache(CACHE_SIZE) def check_Cache(self, key, now): + # if key[1] == ANY_TYPE: + # return None + a = self.cache.get(key) if a: @@ -74,14 +87,42 @@ def check_Cache_ret_time(self, key, now): return a return None - def query_then_cache(self, name_server, q): + def check_timeout(self, q, now, limit): + if (datetime.now() - now) > timedelta(seconds=limit): + if not __debug__: + print("Query Timed Out [1]", q, datetime.now(), now) + q.header["QR"] = 1 + q.header["RCODE"] = SERVFAIL_RCODE + q.header["ANCOUNT"] = 0 + q.header["NSCOUNT"] = 0 + q.header["ARCOUNT"] = 0 + q.answers = [] + return q + return None + + def query_then_cache(self, destination_address_1, q): ### Query + ### https://tools.ietf.org/html/rfc5452#section-9 + + destination_port_1 = 53 + + old_id = copy.copy(q.header["ID"]) + + q.header["ID"] = random.randint(0, 65535) + + q_id_1, q_name_1, q_type_1, q_class_1 = ( + q.header["ID"], + q.question["NAME"], + q.question["QTYPE"], + q.question["QCLASS"], + ) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("", 12000)) - name_server = str(ipaddress.ip_address(name_server)) - sock.connect((name_server, 53)) + sock.bind(("", 0)) + destination_address_1 = str(ipaddress.ip_address(destination_address_1)) + + sock.connect((destination_address_1, destination_port_1)) q.header["RD"] = 0 # we do not want recursive query q.header["QR"] = 0 @@ -92,19 +133,35 @@ def query_then_cache(self, name_server, q): q.answers = [] sock.sendall(q.to_bytes()) if not __debug__: - print("Q: ", q) + print(f"Q {destination_address_1}: ", q) - read, _, _ = select.select([sock], [], [], 10) # Timeout after 10 sec + read, _, _ = select.select([sock], [], [], 15) # Timeout after 15 sec if read: - answer, _ = sock.recvfrom(512) + answer, (destination_address_2, destination_port_2) = sock.recvfrom(1024) else: + q.header["ID"] = old_id return q + if destination_address_1 != destination_address_2 or destination_port_1 != destination_port_2: + q.header["ID"] = old_id + return q # address/port do not match + sock.close() a = DNSQuery(answer) - ### Caching + q_id_2, q_name_2, q_type_2, q_class_2 = ( + a.header["ID"], + a.question["NAME"], + a.question["QTYPE"], + a.question["QCLASS"], + ) + + if q_id_1 != q_id_2 or q_name_1 != q_name_2 or q_type_1 != q_type_2 or q_class_1 != q_class_2: + q.header["ID"] = old_id + return q # id/name/type/class do not match + + ### Caching if a.header["RCODE"] == NOERROR_RCODE: new_rr = defaultdict(list) for record in a.answers: @@ -127,6 +184,8 @@ def query_then_cache(self, name_server, q): if not __debug__: print("A: ", a) + q.header["ID"] = old_id + a.header["ID"] = old_id return a def get_dns_response(self, query): @@ -140,45 +199,34 @@ def get_dns_response(self, query): ### Reject EDNS Reference: https://tools.ietf.org/html/rfc6891#section-6 if q.header["ARCOUNT"] and any(rec["TYPE"] == 41 for rec in q.answers): q.header["QR"] = 1 + q.header["RD"] = 0 + q.header["RA"] = 0 q.header["RCODE"] = REFUSED_RCODE q.header["ANCOUNT"] = 0 q.header["NSCOUNT"] = 0 q.header["ARCOUNT"] = 0 q.answers = [] - return q.to_bytes() + return q.to_bytes() ### RETURNING sname, stype, sclass = ( q.question["NAME"].decode("utf-8").lower(), q.question["QTYPE"], q.question["QCLASS"], ) + # TODO Handle ANY Query - """ - if stype == ANY_TYPE: - for record_type in [A, AAAA, ....]: - check if (sname, record_type, sclass) in cache: - append rec to q.answer - q.header["ANCOUNT] = len(q.answer) - return q - """ + # if stype == ANY_TYPE: + # for record_type in ANY_TYPE: + # check if (sname, record_type, sclass) in cache: + # append rec to q.answer + # q.header["ANCOUNT] = len(q.answer) + # return return self.recursive_lookup( q, sname, stype, sclass, datetime.now() ).to_bytes() # NOTE datetime.now() was not updating until I added in as a parameter def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60): - if (datetime.now() - now) > timedelta(seconds=limit): # look up took longer than 100s - if not __debug__: - print("Query Timed Out [1]", q, datetime.now(), now) - q.header["QR"] = 1 - q.header["RA"] = 1 - q.header["RCODE"] = SERVFAIL_RCODE - q.header["ANCOUNT"] = 0 - q.header["NSCOUNT"] = 0 - q.header["ARCOUNT"] = 0 - q.answers = [] - return q - key = (sname, stype, sclass) ### STEP 1 ### a = self.check_Cache_ret_time(key, now) @@ -192,77 +240,19 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60 (record["expire_time"] - datetime.now()).total_seconds() ) # UPDATE the TTL of records q.answers = [rec["resp"] for rec in a] - return q + return q ### RETURNING while True: - if (datetime.now() - now) > timedelta(seconds=limit): # look up took longer than 100s - if not __debug__: - print("Query Timed Out [2]", q, datetime.now(), now) - q.header["QR"] = 1 - q.header["RA"] = 1 - q.header["RCODE"] = SERVFAIL_RCODE - q.header["ANCOUNT"] = 0 - q.header["NSCOUNT"] = 0 - q.header["ARCOUNT"] = 0 - q.answers = [] - return q - - ### STEP 2 ### - sbelt = [ - "198.41.0.4", # root A - "199.9.14.201", # root B - "195.129.12.83", # dutch - # "1.1.1.1", # cloudflare - ] - mat_slist = defaultdict(list) - split_sname = sname.split(".") - for i in range(len(split_sname)): - slist = [] - reduced_sname = ".".join(split_sname[i:]) - if not reduced_sname: - continue - reduced_key = (reduced_sname, A_TYPE, sclass) - if a := self.check_Cache(reduced_key, now): # append the ip of name servers - slist.extend([rec["RDATA"][0] for rec in a]) - else: - reduced_key = (reduced_sname, NS_TYPE, sclass) - ns_records = self.check_Cache(reduced_key, now) - if ns_records: - for ns in ns_records: - ns_name = ns["RDATA"][0] - reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass) - if a := self.check_Cache(reduced_key, now): - slist.extend([rec["RDATA"][0] for rec in a]) - if not slist: - # TODO Kick off parallel process to look for the ip addresses of NS server - # parallel_thread = threading.Thread(target=self.recursive_lookup, args=(q, *reduced_key, now, 30)) - # parallel_thread.start() - - ### SINGLE THREAD for now - original_name = q.question["NAME"] - for ns in ns_records: - ns_name = ns["RDATA"][0] - q.question["NAME"] = ns_name - reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass) - _ = self.recursive_lookup(q, *reduced_key, now) - if a := self.check_Cache(reduced_key, now): - slist.extend([rec["RDATA"][0] for rec in a]) - - q.question["NAME"] = original_name - - mat_slist[reduced_sname] = slist - - mat_slist["."] = sbelt - if not __debug__: - for k, value in mat_slist.items(): - if k == ".": - continue - print(k, ":", [socket.inet_ntoa(v) for v in value]) - - slist = list(itertools.chain.from_iterable(mat_slist.values())) # flatten mat_slist + if (a := self.check_timeout(q, now, limit)) : + return a + slist = self.build_slist(q, *key, now) + if not __debug__: + print("SLIST: ", slist) ### STEP 3 ### for ns in slist: + name_server = str(ipaddress.ip_address(ns)) # DEBUGGING PURPOSES + a = self.query_then_cache(ns, q) ### STEP 4 ### @@ -272,22 +262,21 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60 if ( (gotAns and a.header["RCODE"] == NOERROR_RCODE) # we found an Answer or a.header["RCODE"] == NAMEERROR_RCODE # there was a name error - or ( - a.answers and a.header["AA"] == 1 and [rec for rec in a.answers if rec["TYPE"] == SOA_TYPE] - ) # SOA response + or (a.header["AA"] == 1 and [rec for rec in a.answers if rec["TYPE"] == SOA_TYPE]) # SOA response or ( q.header["ANCOUNT"] and [rec for rec in a.answers if rec["TYPE"] == A_TYPE] ) # found an ans but name does not match quetion name ): + a.header["AA"] = 0 a.header["RA"] = 1 - return a + return a ### RETURNING ### STEP 4.2 ### - elif a.answers and ([rec for rec in a.answers if rec["TYPE"] == NS_TYPE]): # NSNAME + elif [rec for rec in a.answers if rec["TYPE"] == NS_TYPE]: # NSNAME break ### STEP 4.3 ### - elif a.answers and (cname_rec := [rec for rec in a.answers if rec["TYPE"] == CNAME_TYPE]): # CNAME + elif (cname_rec := [rec for rec in a.answers if rec["TYPE"] == CNAME_TYPE]) : # CNAME new_sname = random.choice(cname_rec)["RDATA"][0] original_name = q.question["NAME"] q.question["NAME"] = new_sname @@ -295,6 +284,57 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60 resp.question["NAME"] = original_name return resp + def build_slist(self, q, sname, stype, sclass, now) -> list: + ### STEP 2 ### + sbelt = [ + "198.41.0.4", # root A + "199.9.14.201", # root B + "195.129.12.83", # dutch + ] + mat_slist = defaultdict(list) + split_sname = sname.split(".") + for i in range(len(split_sname)): + slist = [] + reduced_sname = ".".join(split_sname[i:]) + if not reduced_sname: + continue + reduced_key = (reduced_sname, A_TYPE, sclass) + if a := self.check_Cache(reduced_key, now): # append the ip of name servers + slist.extend([rec["RDATA"][0] for rec in a]) + + reduced_key = (reduced_sname, NS_TYPE, sclass) + ns_records = self.check_Cache(reduced_key, now) + if ns_records: + for ns in ns_records: + ns_name = ns["RDATA"][0] + reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass) + if a := self.check_Cache(reduced_key, now): + slist.extend([rec["RDATA"][0] for rec in a]) + if not slist: + ### SINGLE THREAD for now + original_name = q.question["NAME"] + for ns in ns_records: + ns_name = ns["RDATA"][0] + q.question["NAME"] = ns_name + reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass) + _ = self.recursive_lookup(q, *reduced_key, now) + if a := self.check_Cache(reduced_key, now): + slist.extend([rec["RDATA"][0] for rec in a]) + + q.question["NAME"] = original_name + + mat_slist[reduced_sname] = slist + + mat_slist["."] = sbelt + if not __debug__: + for k, value in mat_slist.items(): + if k == ".": + continue + print(k, ":", [socket.inet_ntoa(v) for v in value]) + + slist = list(itertools.chain.from_iterable(mat_slist.values())) # flatten mat_slist + return slist + parser = argparse.ArgumentParser(description="""This is a DNS resolver""") parser.add_argument( From 9acb0b447e5c1063b7d27d3e388e7aecbe5e9b2e Mon Sep 17 00:00:00 2001 From: JasonCheng1 Date: Fri, 7 May 2021 11:30:17 -0400 Subject: [PATCH 3/3] final commit before submit --- my_server.py | 40 +++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/my_server.py b/my_server.py index 27993ca..89c638c 100644 --- a/my_server.py +++ b/my_server.py @@ -2,7 +2,6 @@ from resolver_backround import DnsResolver import threading import socket -import struct import argparse from sys import argv from helper_funcs import DNSQuery @@ -29,7 +28,7 @@ MX_TYPE = 15 TXT_TYPE = 16 ANY_TYPE = 255 -# ANY_TYPE = [A_TYPE, NS_TYPE, CNAME_TYPE, SOA_TYPE, WKS_TYPE, PTR_TYPE, HINFO_TYPE, MINFO_TYPE, MX_TYPE, TXT_TYPE] +TYPES = [A_TYPE, NS_TYPE, CNAME_TYPE, SOA_TYPE, WKS_TYPE, PTR_TYPE, HINFO_TYPE, MINFO_TYPE, MX_TYPE, TXT_TYPE] # RCODE Values NOERROR_RCODE = 0 @@ -67,8 +66,6 @@ def __init__(self, port): self.cache = LRUCache(CACHE_SIZE) def check_Cache(self, key, now): - # if key[1] == ANY_TYPE: - # return None a = self.cache.get(key) @@ -135,7 +132,7 @@ def query_then_cache(self, destination_address_1, q): if not __debug__: print(f"Q {destination_address_1}: ", q) - read, _, _ = select.select([sock], [], [], 15) # Timeout after 15 sec + read, _, _ = select.select([sock], [], [], 8) # Timeout after 15 sec if read: answer, (destination_address_2, destination_port_2) = sock.recvfrom(1024) else: @@ -179,6 +176,7 @@ def query_then_cache(self, destination_address_1, q): self.cache.put(key, val) # TODO Support Negative Caching + # https://tools.ietf.org/html/rfc2308#section-5 and section-6 # else a.header["RCODE"] == NXDOMAIN ... if not __debug__: @@ -214,13 +212,24 @@ def get_dns_response(self, query): q.question["QCLASS"], ) - # TODO Handle ANY Query - # if stype == ANY_TYPE: - # for record_type in ANY_TYPE: - # check if (sname, record_type, sclass) in cache: - # append rec to q.answer - # q.header["ANCOUNT] = len(q.answer) - # return + ### TODO Handle norecurse + # + + ### Handle ANY Query + # Traverse through cache get anything that matches with sname + if stype == ANY_TYPE: + q.header["QR"] = 1 + q.header["AA"] = 0 # Because result is from a cache + q.header["RA"] = 1 + now = datetime.now() + for record_type in TYPES: + new_key = (sname, record_type, sclass) + if (a := self.check_Cache_ret_time(new_key, now)) : + for record in a: + record["resp"]["TTL"] = int((record["expire_time"] - now).total_seconds()) + q.answers.extend([rec["resp"] for rec in a]) + q.header["ANCOUNT"] = len(q.answers) + return q.to_bytes() return self.recursive_lookup( q, sname, stype, sclass, datetime.now() @@ -229,8 +238,7 @@ def get_dns_response(self, query): def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60): key = (sname, stype, sclass) ### STEP 1 ### - a = self.check_Cache_ret_time(key, now) - if a: + if (a := self.check_Cache_ret_time(key, now)) : q.header["QR"] = 1 q.header["ANCOUNT"] = len(a) q.header["AA"] = 0 # Because result is from a cache @@ -243,7 +251,7 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60 return q ### RETURNING while True: - if (a := self.check_timeout(q, now, limit)) : + if (a := self.check_timeout(q, now, limit)) : # Query has taken too long return a slist = self.build_slist(q, *key, now) @@ -251,8 +259,6 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60 print("SLIST: ", slist) ### STEP 3 ### for ns in slist: - name_server = str(ipaddress.ip_address(ns)) # DEBUGGING PURPOSES - a = self.query_then_cache(ns, q) ### STEP 4 ###