|
1 | 1 | import logging
|
2 | 2 | import random
|
3 |
| - |
| 3 | +import os |
4 | 4 | import numpy as np
|
5 | 5 | import pandas as pd
|
6 | 6 |
|
7 | 7 | from gensim.models import doc2vec
|
8 | 8 | from sklearn.linear_model import LogisticRegression
|
9 | 9 | from sklearn.model_selection import train_test_split
|
10 | 10 | from sklearn.metrics import accuracy_score, f1_score
|
| 11 | +from sklearn.externals import joblib |
11 | 12 |
|
12 | 13 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
|
13 | 14 |
|
@@ -99,9 +100,21 @@ def test_classifier(d2v, classifier, testing_vectors, testing_labels):
|
99 | 100 | logging.info('Testing accuracy: {}'.format(accuracy_score(testing_labels, testing_predictions)))
|
100 | 101 | logging.info('Testing F1 score: {}'.format(f1_score(testing_labels, testing_predictions, average='weighted')))
|
101 | 102 |
|
def save_classifier(model, filename):
    """Serialize ``model`` to ``filename`` using joblib.

    Args:
        model: Object to persist; handed to ``joblib.dump`` unchanged.
        filename: Destination path for the serialized object.
    """
    joblib.dump(model, filename)
| 105 | + |
def load_classifier(filename):
    """Load an object previously serialized with joblib.

    Args:
        filename: Path to the serialized file. Relative paths are resolved
            against the current working directory; absolute paths are
            honored as given.

    Returns:
        The deserialized object, or ``None`` when ``filename`` does not
        refer to an existing regular file.
    """
    # Test the path exactly as it will be opened. Prefixing './' would turn
    # an absolute path into a bogus relative one and reject valid files.
    if not os.path.isfile(filename):
        return None
    # NOTE(review): joblib deserialization is pickle-based; only load files
    # that come from a trusted source.
    return joblib.load(filename)
102 | 112 |
|
if __name__ == "__main__":
    # Read the dataset splits, then train the doc2vec model and classifier.
    splits = read_dataset('dataset.csv')
    x_train, x_test, y_train, y_test, all_data = splits
    d2v_model = train_doc2vec(all_data)
    classifier = train_classifier(d2v_model, x_train, y_train)

    # Round-trip the classifier through disk so that the evaluation below
    # runs against the persisted copy rather than the in-memory one.
    model_path = "joblib_model.pkl"
    save_classifier(classifier, model_path)
    restored_classifier = load_classifier(model_path)
    test_classifier(d2v_model, restored_classifier, x_test, y_test)
0 commit comments