-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathDataSet.py
More file actions
33 lines (26 loc) · 1.02 KB
/
DataSet.py
File metadata and controls
33 lines (26 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Importing the dataset.
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
# Creating independent and dependent variables.
X, y = mnist['data'], mnist['target']
# Splitting the data into training set and test set.
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
"""
The training set is already shuffled for us, which is good as this guarantees that all
cross-validation folds will be similar.
"""
# Training a binary classifier.
y_train_5 = (y_train == 5) # True for all 5s, False for all other digits.
y_test_5 = (y_test == 5)
"""
Building a dumb classifier that just classifies every single image in the “not-5” class.
"""
from sklearn.model_selection import cross_val_score
from sklearn.base import BaseEstimator
class Never5Classifier(BaseEstimator):
def fit(self, X, y=None):
pass
def predict(self, X):
return np.zeros((len(X), 1), dtype=bool)
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")