Get already imported module in current file by name


I am studying sklearn and have written a Classifier class to handle common classification tasks. It needs a way to determine which Estimator to use:

# Classifier
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics  # used by classification_report and confusion_matrix below

class Classifier(object):
    def __init__(self, method='LinearSVC', *args, **kwargs):
        Estimator = getattr(xxx, method, None)  # xxx is the unknown this question asks about
        self.Estimator = Estimator
        self._model = Estimator(*args, **kwargs)

    def fit(self, data, target):
        return self._model.fit(data, target)

    def predict(self, data):
        return self._model.predict(data)

    def score(self, X, y, sample_weight=None):
        return self._model.score(X, y, sample_weight=sample_weight)

    def persist_model(self):
        pass

    def get_model(self):
        return self._model

    def classification_report(self, expected, predicted):
        return metrics.classification_report(expected, predicted)

    def confusion_matrix(self, expected, predicted):
        return metrics.confusion_matrix(expected, predicted)

I want to get the Estimator by name, but what should xxx be? Or is there a better way to do this?
I could build a dict to store the imported estimators, but that does not seem like a good approach.


There are 3 answers

2
PascalVKooten On

In this case it is advised to simply use the class directly as an argument.

You never have to deal with it as a string: classes can be compared directly, so LinearSVC is LinearSVC is True, and comparing LinearSVC with any other class is False.

Accepting a name and then looking up the class is like accepting an integer as an argument only to convert it to a string before using it: does that make sense? If a string is what you need, simply require a string.

Proposed code:

from sklearn.svm import LinearSVC

class Classifier(object):
    def __init__(self, model=LinearSVC, *args, **kwargs):
        # model is the estimator class itself, not its name
        self._model = model(*args, **kwargs)

You can then do:

myclf = Classifier(model=LinearSVC)  # plus any arguments the estimator accepts
isinstance(myclf._model, LinearSVC)  # True
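
A minimal sketch of the class-as-argument approach that runs without scikit-learn. StubSVC and StubNB are hypothetical stand-ins invented for this example; with sklearn installed you would pass LinearSVC or GaussianNB instead.

```python
class StubSVC:
    """Hypothetical stand-in for an estimator such as LinearSVC."""

    def __init__(self, C=1.0):
        self.C = C


class StubNB:
    """Hypothetical stand-in for an estimator such as GaussianNB."""


class Classifier:
    def __init__(self, model=StubSVC, *args, **kwargs):
        # Keep the class for later identity checks, then build the instance.
        self.Estimator = model
        self._model = model(*args, **kwargs)


clf = Classifier(model=StubSVC, C=0.5)
uses_svc = clf.Estimator is StubSVC       # identity check on the class
is_svc = isinstance(clf._model, StubSVC)  # instance check on the model
default_clf = Classifier()                # falls back to StubSVC
nb_clf = Classifier(StubNB)               # class passed positionally
```

Extra positional and keyword arguments go straight to the estimator's constructor, so the wrapper does not need to know which options each estimator accepts.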

As per the comment:

You can then also initialise a dict at start like:

from sklearn.svm import LinearSVC

str_to_model = {'LinearSVC' : LinearSVC}

class Classifier(object):
    def __init__(self, model = "LinearSVC", *args, **kwargs):
        self._model = str_to_model[model](*args, **kwargs)

It is cleaner to work with a KeyError (the string/model is not in the dict, and you know exactly which names are, since you defined them yourself) than to check globals(), which sounds pretty nasty!
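
One possible way to extend the dict idea, sketched under assumptions: the registry is built from the classes' own names, and an unknown name is reported together with the valid choices. StubSVC, StubNB and make_model are hypothetical names used only so the example runs without scikit-learn.

```python
class StubSVC:
    """Hypothetical stand-in for an estimator such as LinearSVC."""

    def __init__(self, C=1.0):
        self.C = C


class StubNB:
    """Hypothetical stand-in for an estimator such as GaussianNB."""


# Build the name -> class mapping from the classes themselves,
# so each name is written only once.
STR_TO_MODEL = {cls.__name__: cls for cls in (StubSVC, StubNB)}


def make_model(name, *args, **kwargs):
    """Instantiate a registered model by name, or fail with a clear message."""
    try:
        model_cls = STR_TO_MODEL[name]
    except KeyError:
        known = ", ".join(sorted(STR_TO_MODEL))
        raise ValueError(
            f"unknown model {name!r}; expected one of: {known}"
        ) from None
    return model_cls(*args, **kwargs)


svc = make_model("StubSVC", C=0.5)
```

Only names you registered can be instantiated, so a typo or an unrelated global name fails immediately instead of returning some other object.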

1
Stefano M On

The built-in function globals() does the trick: provided LogisticRegression has been imported in the module, you can check that globals()['LogisticRegression'] is LogisticRegression evaluates to True.

ADDENDUM

  • Safe: nothing 'nasty' can happen by evaluating globals()[method]
  • Efficient: overhead is negligible with respect to some_method_dict[method]
  • Simple: globals()[method] is just the shortest answer to the question.

Whether this is Pythonic or not, I don't know, but the globals() builtin is there to be used, so why choose a more complicated solution?

To be explicit,

Estimator = getattr(..., method, None)

can be implemented as

Estimator = globals().get(method)

if returning None is preferred to raising a KeyError when method has not been imported.
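
To illustrate the difference between the two lookups, here is a small sketch. StubSVC is a hypothetical stand-in for an imported estimator, and the two helper functions are names made up for this example.

```python
class StubSVC:
    """Hypothetical stand-in for an imported estimator such as LinearSVC."""


def lookup_or_none(name):
    # dict.get returns None when the name is not a module-level global.
    return globals().get(name)


def lookup_or_raise(name):
    # Indexing raises KeyError when the name is not a module-level global.
    return globals()[name]


found = lookup_or_none("StubSVC")
missing = lookup_or_none("NotImported")

try:
    lookup_or_raise("NotImported")
except KeyError:
    raised = True
else:
    raised = False
```

Keep in mind that globals() holds every module-level name, not only estimators, so lookup_or_none("lookup_or_raise") would return a function rather than a class.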

0
Guy On

Two built-in functions may help you: globals() and locals(). Each returns a dict representing a symbol table: globals() for the module's global namespace and locals() for the current local scope.

Your code could be Estimator = globals()[method], or you could move the estimator imports into __init__ and use Estimator = locals()[method].
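
A rough sketch of the locals() variant, assuming the imports are moved inside the function. Classes from the standard collections module stand in for the sklearn estimators here, and pick_class is a hypothetical name.

```python
import collections


def pick_class(method):
    # Function-level imports bind local names, so they show up in locals().
    from collections import Counter, OrderedDict
    return locals()[method]


chosen_is_counter = pick_class("Counter") is collections.Counter
chosen_is_ordered = pick_class("OrderedDict") is collections.OrderedDict

try:
    pick_class("deque")  # not imported inside pick_class
except KeyError:
    unknown_raises = True
else:
    unknown_raises = False
```

Every local variable is visible this way, including the method argument itself, so the lookup is only as strict as the function's local namespace.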