I am studying sklearn and have written a class, Classifier, for common classification tasks. It needs a way to choose which estimator to use:
# Classifier
from sklearn import metrics
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

class Classifier(object):
    def __init__(self, method='LinearSVC', *args, **kwargs):
        # Look the estimator class up by name; xxx is the open question
        Estimator = getattr(xxx, method, None)
        self.Estimator = Estimator
        self._model = Estimator(*args, **kwargs)

    def fit(self, data, target):
        return self._model.fit(data, target)

    def predict(self, data):
        return self._model.predict(data)

    def score(self, X, y, sample_weight=None):
        # Forward the caller's weights instead of always passing None
        return self._model.score(X, y, sample_weight=sample_weight)

    def persist_model(self):
        pass

    def get_model(self):
        return self._model

    def classification_report(self, expected, predicted):
        return metrics.classification_report(expected, predicted)

    def confusion_matrix(self, expected, predicted):
        return metrics.confusion_matrix(expected, predicted)
I want to get the estimator class by its name, but what should xxx be?
Or is there a better way to do this?
I could build a dict to store the imported classes, but that approach doesn't seem very good.
In this case, it is better to simply pass the class itself as the argument.
Then you never have to handle it as a string: you can check
LinearSVC is LinearSVC
and compare it against any other class in the same way. Think of it like accepting an integer as an argument and then converting it to a string before using it: does that make sense? If you need a string, you simply require a string; likewise, if you need a class, require the class.
Proposed code:
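A minimal sketch of that idea, assuming the constructor receives the estimator class itself (the parameter name `estimator_class` and the LinearSVC default are assumptions, not taken from the original) and forwards any remaining arguments to it:

```python
from sklearn import metrics
from sklearn.svm import LinearSVC


class Classifier(object):
    def __init__(self, estimator_class=LinearSVC, *args, **kwargs):
        # estimator_class is the class object itself, so no name lookup is needed
        self.Estimator = estimator_class
        self._model = estimator_class(*args, **kwargs)

    def fit(self, data, target):
        return self._model.fit(data, target)

    def predict(self, data):
        return self._model.predict(data)

    def score(self, X, y, sample_weight=None):
        return self._model.score(X, y, sample_weight=sample_weight)

    def get_model(self):
        return self._model

    def classification_report(self, expected, predicted):
        return metrics.classification_report(expected, predicted)

    def confusion_matrix(self, expected, predicted):
        return metrics.confusion_matrix(expected, predicted)
```

Hyperparameters go straight through to the estimator, e.g. Classifier(DecisionTreeClassifier, max_depth=2).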
You can then do:
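A usage sketch that shows the underlying point without any wrapper: estimator classes are ordinary Python objects, so they can be stored in a tuple, passed around, instantiated, and identified with `is` instead of by comparing names. The toy data and the per-estimator keyword arguments are illustrative assumptions:

```python
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC

# Tiny made-up dataset: two points per class
X = [[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 1.1]]
y = [0, 0, 1, 1]

predictions = {}
for Estimator in (LinearSVC, GaussianNB, KNeighborsClassifier):
    # Identity check on the class object, no string comparison involved
    kwargs = {"n_neighbors": 1} if Estimator is KNeighborsClassifier else {}
    model = Estimator(**kwargs).fit(X, y)
    predictions[Estimator.__name__] = int(model.predict([[0.95, 1.0]])[0])

print(predictions)
```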
As per the comment:
You can also initialise a dict at the start, like:
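If names are still needed (for instance, when the choice comes from a config file), a minimal sketch of such an explicit mapping could look like this. `ESTIMATORS` and `make_estimator` are hypothetical names chosen for illustration:

```python
from sklearn.linear_model import SGDClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier

# Hypothetical registry: only the estimators listed here can be chosen by name
ESTIMATORS = {
    "SVC": SVC,
    "LinearSVC": LinearSVC,
    "SGDClassifier": SGDClassifier,
    "KNeighborsClassifier": KNeighborsClassifier,
    "GaussianNB": GaussianNB,
    "DecisionTreeClassifier": DecisionTreeClassifier,
}


def make_estimator(name, *args, **kwargs):
    # An unregistered name raises KeyError, so typos fail loudly
    return ESTIMATORS[name](*args, **kwargs)
```

For example, make_estimator("SVC", kernel="rbf") returns an SVC instance, while make_estimator("SCV") raises KeyError.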
It's cleaner to work with a KeyError (the string/model does not exist, and you know it, since you never defined it) than to check globals(), which sounds pretty nasty!