-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecisionTree1.py
More file actions
23 lines (19 loc) · 901 Bytes
/
decisionTree1.py
File metadata and controls
23 lines (19 loc) · 901 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.inspection import DecisionBoundaryDisplay
import numpy as np
import matplotlib.pyplot as plt
# Fit a decision tree on every pair of the four iris features and plot each
# tree's decision surface together with the training samples, in a 2 x 3 grid.
iris = load_iris()
n_classes = len(iris.target_names)
plot_colors = "ryb"  # one scatter colour per class, in target-index order

for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
    # Restrict the data to the two features in this pair.
    X = iris.data[:, pair]
    y = iris.target

    # Pin random_state so the fitted tree (and therefore the plot) is
    # reproducible: sklearn permutes features at each split, so ties between
    # equally good splits could otherwise resolve differently between runs.
    clf = DecisionTreeClassifier(random_state=0).fit(X, y)

    # Six feature pairs -> one cell of the 2 x 3 subplot grid each.
    ax = plt.subplot(2, 3, pairidx + 1)
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        cmap=plt.cm.RdYlBu,
        response_method="predict",
        ax=ax,
        xlabel=iris.feature_names[pair[0]],
        ylabel=iris.feature_names[pair[1]],
    )

    # Overlay the training points, one colour per class. A boolean mask
    # selects that class's rows directly. No cmap is passed because each call
    # uses a single fixed colour, and matplotlib ignores (and warns about) a
    # colormap in that case.
    for i, color in zip(range(n_classes), plot_colors):
        mask = y == i
        ax.scatter(
            X[mask, 0],
            X[mask, 1],
            c=color,
            label=iris.target_names[i],
            edgecolors="black",
        )

# The scatter `label`s are only shown once a legend is drawn. It goes on the
# current (last) subplot; every subplot shows the same classes and colours.
plt.legend(loc="lower right", borderpad=0, handletextpad=0)
# Lay out the grid once, after every subplot has been created.
plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
plt.show()