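"""Benchmark classifiers on the sparse 20 newsgroups dataset.

Trains each selected scikit-learn estimator, then reports training time,
prediction time, and test-set accuracy.
"""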
from __future__ import print_function, division
from time import time
import argparse
import numpy as np

from sklearn.dummy import DummyClassifier

from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.metrics import accuracy_score
from sklearn.utils.validation import check_array

from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB

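# Candidate estimators, keyed by the name accepted on the command line.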
ESTIMATORS = {
    "dummy": DummyClassifier(),
    "random_forest": RandomForestClassifier(n_estimators=100,
                                            max_features="sqrt",
                                            min_samples_split=10),
    "extra_trees": ExtraTreesClassifier(n_estimators=100,
                                        max_features="sqrt",
                                        min_samples_split=10),
    "logistic_regression": LogisticRegression(),
    "naive_bayes": MultinomialNB(),
    "adaboost": AdaBoostClassifier(n_estimators=10),
}


###############################################################################
# Data

if __name__ == "__main__":

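    # Choose which estimators to run, e.g. "-e logistic_regression naive_bayes".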
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--estimators', nargs="+", required=True,
                        choices=ESTIMATORS)
    args = vars(parser.parse_args())

    data_train = fetch_20newsgroups_vectorized(subset="train")
    data_test = fetch_20newsgroups_vectorized(subset="test")
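    # Sparse format matters here: CSC suits the column-wise access of tree
    # fitting, while CSR suits the row-wise access of prediction.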
    X_train = check_array(data_train.data, dtype=np.float32,
                          accept_sparse="csc")
    X_test = check_array(data_test.data, dtype=np.float32, accept_sparse="csr")
    y_train = data_train.target
    y_test = data_test.target

    print("20 newsgroups")
    print("=============")
    print("X_train.shape = {0}".format(X_train.shape))
    print("X_train.format = {0}".format(X_train.format))
    print("X_train.dtype = {0}".format(X_train.dtype))
    print("X_train density = {0}"
          "".format(X_train.nnz / np.prod(X_train.shape)))
    print("y_train {0}".format(y_train.shape))
    print("X_test {0}".format(X_test.shape))
    print("X_test.format = {0}".format(X_test.format))
    print("X_test.dtype = {0}".format(X_test.dtype))
    print("y_test {0}".format(y_test.shape))
    print()

    print("Classifier Training")
    print("===================")
    accuracy, train_time, test_time = {}, {}, {}
    for name in sorted(args["estimators"]):
        clf = ESTIMATORS[name]
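        # Seed estimators that support it; set_params raises for those that
        # do not (e.g. MultinomialNB), in which case seeding is skipped.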
        try:
            clf.set_params(random_state=0)
        except (TypeError, ValueError):
            pass

        print("Training %s ... " % name, end="")
        t0 = time()
        clf.fit(X_train, y_train)
        train_time[name] = time() - t0
        t0 = time()
        y_pred = clf.predict(X_test)
        test_time[name] = time() - t0
        accuracy[name] = accuracy_score(y_test, y_pred)
        print("done")

    print()
    print("Classification performance:")
    print("===========================")
    print()
    print("%s %s %s %s" % ("Classifier      ", "train-time", "test-time",
                           "Accuracy"))
    print("-" * 44)
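    # Report results from worst to best accuracy.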
    for name in sorted(accuracy, key=accuracy.get):
        print("%s %s %s %s" % (name.ljust(16),
                               ("%.4fs" % train_time[name]).center(10),
                               ("%.4fs" % test_time[name]).center(10),
                               ("%.4f" % accuracy[name]).center(10)))

    print()