-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecision_tree.py
More file actions
49 lines (42 loc) · 1.79 KB
/
decision_tree.py
File metadata and controls
49 lines (42 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#-------------------------------------------------------------------------
# AUTHOR: Nicholas Hoang
# FILENAME: decision_tree.py
# SPECIFICATION: Outputs a decision tree based on the contact_lens.csv file, which contains data on patients who do and do not need contact lenses
# FOR: CS 4210- Assignment #1
# TIME SPENT: 4 hours
#-----------------------------------------------------------*/
#IMPORTANT NOTE: DO NOT USE ANY ADVANCED PYTHON LIBRARY TO COMPLETE THIS CODE SUCH AS numpy OR pandas. You have to work here only with standard
# dictionaries, lists, and arrays
#importing some Python libraries
from sklearn import tree
import matplotlib.pyplot as plt
import csv
db = []
X = []
Y = []
#reading the data in a csv file
with open('contact_lens.csv', 'r') as csvfile:
reader = csv.reader(csvfile)
for i, row in enumerate(reader):
if i > 0: #skipping the header
db.append (row)
print(row)
#transform the original categorical training features into numbers and add to the 4D array X. For instance Young = 1, Prepresbyopic = 2, Presbyopic = 3
#--> add your Python code here
for row in db:
X.append([
1 if row[0] == 'Young' else 2 if row[0] == 'Prepresbyopic' else 3,
1 if row[1] == 'Myope' else 2,
1 if row[2] == 'Yes' else 2,
1 if row[3] == 'Normal' else 2
])
#transform the original categorical training classes into numbers and add to the vector Y. For instance Yes = 1, No = 2
#--> add your Python code here
for row in db:
Y.append(1 if row[4] == 'Yes' else 2)
#fitting the decision tree to the data
clf = tree.DecisionTreeClassifier(criterion = 'entropy')
clf = clf.fit(X, Y)
# #plotting the decision tree
tree.plot_tree(clf, feature_names=['Age', 'Spectacle', 'Astigmatism', 'Tear'], class_names=['Yes','No'], filled=True, rounded=True)
plt.show()