Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 197 additions & 79 deletions my_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from resolver_backround import DnsResolver
import threading
import socket
import struct
import argparse
from sys import argv
from time import sleep
from helper_funcs import DNSQuery
from datetime import datetime, timedelta
from collections import defaultdict, OrderedDict
import copy
import random
import ipaddress
import copy
import itertools
import select

MAX_LEVEL = 10 # max number of iterative queries we make before we stop
CACHE_SIZE = 100000
Expand All @@ -20,9 +20,21 @@
A_TYPE = 1
NS_TYPE = 2
CNAME_TYPE = 5

SOA_TYPE = 6
WKS_TYPE = 11
PTR_TYPE = 12
HINFO_TYPE = 13
MINFO_TYPE = 14
MX_TYPE = 15
TXT_TYPE = 16
ANY_TYPE = 255
TYPES = [A_TYPE, NS_TYPE, CNAME_TYPE, SOA_TYPE, WKS_TYPE, PTR_TYPE, HINFO_TYPE, MINFO_TYPE, MX_TYPE, TXT_TYPE]

# RCODE Values
NOERROR_RCODE = 0
SERVFAIL_RCODE = 2
NAMEERROR_RCODE = 3
REFUSED_RCODE = 5


class LRUCache:
Expand Down Expand Up @@ -54,7 +66,9 @@ def __init__(self, port):
self.cache = LRUCache(CACHE_SIZE)

def check_Cache(self, key, now):

a = self.cache.get(key)

if a:
a = [record for record in a if record["expire_time"] > now] # Get rid of expired records
self.cache.put(key, a if a else None)
Expand All @@ -70,27 +84,86 @@ def check_Cache_ret_time(self, key, now):
return a
return None

def query_then_cache(self, name_server, q):
def check_timeout(self, q, now, limit):
if (datetime.now() - now) > timedelta(seconds=limit):
if not __debug__:
print("Query Timed Out [1]", q, datetime.now(), now)
q.header["QR"] = 1
q.header["RCODE"] = SERVFAIL_RCODE
q.header["ANCOUNT"] = 0
q.header["NSCOUNT"] = 0
q.header["ARCOUNT"] = 0
q.answers = []
return q
return None

def query_then_cache(self, destination_address_1, q):
### Query
### https://tools.ietf.org/html/rfc5452#section-9

destination_port_1 = 53

old_id = copy.copy(q.header["ID"])

q.header["ID"] = random.randint(0, 65535)

q_id_1, q_name_1, q_type_1, q_class_1 = (
q.header["ID"],
q.question["NAME"],
q.question["QTYPE"],
q.question["QCLASS"],
)
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("", 8000))
name_server = str(ipaddress.ip_address(name_server))
sock.connect((name_server, 53))

sock.bind(("", 0))
destination_address_1 = str(ipaddress.ip_address(destination_address_1))

sock.connect((destination_address_1, destination_port_1))

q.header["RD"] = 0 # we do not want recursive query
q.header["QR"] = 0
q.header["ANCOUNT"] = 0
q.header["NSCOUNT"] = 0
q.header["NSCOUNT"] = 0
q.header["ARCOUNT"] = 0
q.answers = []
sock.sendall(q.to_bytes())
answer = sock.recv(512)
if not __debug__:
print(f"Q {destination_address_1}: ", q)

read, _, _ = select.select([sock], [], [], 8) # Timeout after 15 sec
if read:
answer, (destination_address_2, destination_port_2) = sock.recvfrom(1024)
else:
q.header["ID"] = old_id
return q

if destination_address_1 != destination_address_2 or destination_port_1 != destination_port_2:
q.header["ID"] = old_id
return q # address/port do not match

sock.close()

a = DNSQuery(answer)
### Caching

q_id_2, q_name_2, q_type_2, q_class_2 = (
a.header["ID"],
a.question["NAME"],
a.question["QTYPE"],
a.question["QCLASS"],
)

if q_id_1 != q_id_2 or q_name_1 != q_name_2 or q_type_1 != q_type_2 or q_class_1 != q_class_2:
q.header["ID"] = old_id
return q # id/name/type/class do not match

### Caching
if a.header["RCODE"] == NOERROR_RCODE:
new_rr = defaultdict(list)
for record in a.answers:
key = (
record["NAME"].decode("utf-8"),
record["NAME"].decode("utf-8").lower(),
record["TYPE"],
record["CLASS"],
)
Expand All @@ -101,102 +174,89 @@ def query_then_cache(self, name_server, q):
new_rr[key].append(val)
for key, val in new_rr.items():
self.cache.put(key, val)
print("%%")
# print([[str(num) for num in record["RDATA"]] for record in a.answers])
print(a)

# TODO Support Negative Caching
# https://tools.ietf.org/html/rfc2308#section-5 and section-6
# else a.header["RCODE"] == NXDOMAIN ...

if not __debug__:
print("A: ", a)

q.header["ID"] = old_id
a.header["ID"] = old_id
return a

def get_dns_response(self, query):
# input: A query and any state in self
# returns: the correct response to the query obtained by asking DNS name servers
# Your code goes here, when you change any 'self' variables make sure to use a lock
print("***")
print(query)
print("&&&")
q = DNSQuery(query)
print(q)
if not __debug__:
print("Query: ", q)

### Reject EDNS Reference: https://tools.ietf.org/html/rfc6891#section-6
if q.header["ARCOUNT"] and any(rec["TYPE"] == 41 for rec in q.answers):
q.header["QR"] = 1
q.header["RCODE"] = 4
q.header["RD"] = 0
q.header["RA"] = 0
q.header["RCODE"] = REFUSED_RCODE
q.header["ANCOUNT"] = 0
q.header["NSCOUNT"] = 0
q.header["ARCOUNT"] = 0
q.answers = []
return q.to_bytes()
return q.to_bytes() ### RETURNING

sname, stype, sclass = (
q.question["NAME"].decode("utf-8"),
q.question["NAME"].decode("utf-8").lower(),
q.question["QTYPE"],
q.question["QCLASS"],
)

return self.recursive_lookup(q, sname, stype, sclass).to_bytes()
### TODO Handle norecurse
#

def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=120):
if (datetime.now() - now) > timedelta(seconds=limit): # look up took longer than 100s
print("Query Timed Out")
### Handle ANY Query
# Traverse through cache get anything that matches with sname
if stype == ANY_TYPE:
q.header["QR"] = 1
q.header["RCODE"] = 2
q.header["ANCOUNT"] = 0
q.header["NSCOUNT"] = 0
q.header["ARCOUNT"] = 0
q.answers = []
return q
q.header["AA"] = 0 # Because result is from a cache
q.header["RA"] = 1
now = datetime.now()
for record_type in TYPES:
new_key = (sname, record_type, sclass)
if (a := self.check_Cache_ret_time(new_key, now)) :
for record in a:
record["resp"]["TTL"] = int((record["expire_time"] - now).total_seconds())
q.answers.extend([rec["resp"] for rec in a])
q.header["ANCOUNT"] = len(q.answers)
return q.to_bytes()

return self.recursive_lookup(
q, sname, stype, sclass, datetime.now()
).to_bytes() # NOTE datetime.now() was not updating until I added in as a parameter

def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=60):
key = (sname, stype, sclass)
### STEP 1 ###
a = self.check_Cache_ret_time(key, now)
if a:
if (a := self.check_Cache_ret_time(key, now)) :
q.header["QR"] = 1
q.header["ANCOUNT"] = len(a)
q.header["AA"] = 0 # Because result is from a cache
q.header["RA"] = 1
for record in a:
record["resp"]["TTL"] = int((record["expire_time"] - datetime.now()).total_seconds())
record["resp"]["TTL"] = int(
(record["expire_time"] - datetime.now()).total_seconds()
) # UPDATE the TTL of records
q.answers = [rec["resp"] for rec in a]
return q
return q ### RETURNING

while True:
### STEP 2 ###
sbelt = [
"198.41.0.4",
"199.9.14.201",
"195.129.12.83",
] # A and B root server
slist = []
split_sname = sname.split(".")
for i in range(len(split_sname)):
reduced_sname = ".".join(split_sname[i:])
reduced_key = (reduced_sname, NS_TYPE, sclass)
ns_records = self.check_Cache(reduced_key, now)
if ns_records:
for ns in ns_records:
ns_name = ns["RDATA"][0]
reduced_key = (ns_name.decode("utf-8"), A_TYPE, sclass)
if a := self.check_Cache(reduced_key, now): # append the ip of name servers
random_a = random.choice(a)
slist.extend(random_a["RDATA"])
if ns_records and not slist:
# TODO Kick off parallel process to look for the ip addresses of NS server
# parallel_thread = threading.Thread(target=self.recursive_lookup, args=(q, *reduced_key, now, 30))
# parallel_thread.start()

### SINGLE THREAD For now
original_name = q.question["NAME"]
for ns in ns_records:
ns_name = ns["RDATA"][0]
q.question["NAME"] = ns_name
reduced_key = (ns_name.decode("utf-8"), A_TYPE, sclass)
_ = self.recursive_lookup(q, *reduced_key, now, 60)
a = self.check_Cache(reduced_key, now)
if a:
slist.extend(random.choice(a)["RDATA"])
q.question["NAME"] = original_name

slist.extend(sbelt)
if (a := self.check_timeout(q, now, limit)) : # Query has taken too long
return a

slist = self.build_slist(q, *key, now)
if not __debug__:
print("SLIST: ", slist)
### STEP 3 ###
for ns in slist:
a = self.query_then_cache(ns, q)
Expand All @@ -205,24 +265,82 @@ def recursive_lookup(self, q, sname, stype, sclass, now=datetime.now(), limit=12
gotAns = self.check_Cache(key, now)

### STEP 4.1 ###
if (gotAns and a.header["RCODE"] == NOERROR_RCODE) or a.header[
"RCODE"
] == NAMEERROR_RCODE: # We found it or there's an ans
return a
if (
(gotAns and a.header["RCODE"] == NOERROR_RCODE) # we found an Answer
or a.header["RCODE"] == NAMEERROR_RCODE # there was a name error
or (a.header["AA"] == 1 and [rec for rec in a.answers if rec["TYPE"] == SOA_TYPE]) # SOA response
or (
q.header["ANCOUNT"] and [rec for rec in a.answers if rec["TYPE"] == A_TYPE]
) # found an ans but name does not match quetion name
):
a.header["AA"] = 0
a.header["RA"] = 1
return a ### RETURNING

### STEP 4.2 ###
elif a.answers and (ns_records := [rec for rec in a.answers if rec["TYPE"] == NS_TYPE]): # NSNAME
elif [rec for rec in a.answers if rec["TYPE"] == NS_TYPE]: # NSNAME
break

### STEP 4.3 ###
elif a.answers and (cname_rec := [rec for rec in a.answers if rec["TYPE"] == CNAME_TYPE]): # CNAME
elif (cname_rec := [rec for rec in a.answers if rec["TYPE"] == CNAME_TYPE]) : # CNAME
new_sname = random.choice(cname_rec)["RDATA"][0]
original_name = q.question["NAME"]
q.question["NAME"] = new_sname
resp = self.recursive_lookup(q, new_sname.decode("utf-8"), stype, sclass, now)
resp = self.recursive_lookup(q, new_sname.decode("utf-8").lower(), stype, sclass, now)
resp.question["NAME"] = original_name
return resp

def build_slist(self, q, sname, stype, sclass, now) -> list:
### STEP 2 ###
sbelt = [
"198.41.0.4", # root A
"199.9.14.201", # root B
"195.129.12.83", # dutch
]
mat_slist = defaultdict(list)
split_sname = sname.split(".")
for i in range(len(split_sname)):
slist = []
reduced_sname = ".".join(split_sname[i:])
if not reduced_sname:
continue
reduced_key = (reduced_sname, A_TYPE, sclass)
if a := self.check_Cache(reduced_key, now): # append the ip of name servers
slist.extend([rec["RDATA"][0] for rec in a])

reduced_key = (reduced_sname, NS_TYPE, sclass)
ns_records = self.check_Cache(reduced_key, now)
if ns_records:
for ns in ns_records:
ns_name = ns["RDATA"][0]
reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass)
if a := self.check_Cache(reduced_key, now):
slist.extend([rec["RDATA"][0] for rec in a])
if not slist:
### SINGLE THREAD for now
original_name = q.question["NAME"]
for ns in ns_records:
ns_name = ns["RDATA"][0]
q.question["NAME"] = ns_name
reduced_key = (ns_name.decode("utf-8").lower(), A_TYPE, sclass)
_ = self.recursive_lookup(q, *reduced_key, now)
if a := self.check_Cache(reduced_key, now):
slist.extend([rec["RDATA"][0] for rec in a])

q.question["NAME"] = original_name

mat_slist[reduced_sname] = slist

mat_slist["."] = sbelt
if not __debug__:
for k, value in mat_slist.items():
if k == ".":
continue
print(k, ":", [socket.inet_ntoa(v) for v in value])

slist = list(itertools.chain.from_iterable(mat_slist.values())) # flatten mat_slist
return slist


parser = argparse.ArgumentParser(description="""This is a DNS resolver""")
parser.add_argument(
Expand Down