Scikit-learn: overriding a class method in a classifier

I am trying to override the predict_proba method of a classifier class. As far as I have seen, the easiest approach (where applicable) is to preprocess the input to the base class method or postprocess its output:

from sklearn.ensemble import RandomForestClassifier

class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        X = pre_process(X)  # your own input transform; keep its return value
        ret = super().predict_proba(X)
        return post_process(ret)  # your own output transform
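For illustration, a runnable sketch of that pattern. `pre_process` and `post_process` are hypothetical stand-ins (NaN replacement and rounding), not scikit-learn functions; substitute your own transforms:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def pre_process(X):
    # Hypothetical input transform: replace NaNs with 0.0
    return np.nan_to_num(np.asarray(X, dtype=float), nan=0.0)


def post_process(proba):
    # Hypothetical output transform: round probabilities to 3 decimals
    return np.round(proba, 3)


class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        X = pre_process(X)  # use the returned array; nan_to_num here is not in-place
        proba = super().predict_proba(X)
        return post_process(proba)


X, y = load_iris(return_X_y=True)
clf = RandomForestClassifierWrapper(n_estimators=20, random_state=0).fit(X, y)

X_new = X[:3].copy()
X_new[0, 0] = np.nan
print(clf.predict_proba(X_new))
print(clf.predict(X_new))
```

Keep the output shaped `(n_samples, n_classes)`: `RandomForestClassifier.predict` calls `self.predict_proba` internally, so the override also changes what `predict` returns.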

However, what I actually want is to copy a variable that is created locally inside the base class method, where it is processed and then garbage-collected when the method returns. I want to process the intermediate result stored in that variable myself. Is there a straightforward way to do this without messing with the base class internals?

Accepted answer by thodic:

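Try overriding predict_proba with a copy of the base implementation, and add your own processing at the point where the intermediate result is available: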

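import numpy as np
from joblib import Parallel, delayed
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted


class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        # Mirrors the base implementation using public attributes only, since
        # private helpers such as _partition_estimators may move between
        # versions. Assumes a single-output forest (n_outputs_ == 1).
        check_is_fitted(self)

        # Check data (the trees were fitted on float32 input)
        X = check_array(X, dtype=np.float32, accept_sparse="csr")

        # Parallel loop: one (n_samples, n_classes) array per tree
        all_proba = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
                             backend="threading")(
            delayed(tree.predict_proba)(X) for tree in self.estimators_)

        # do something with all_proba

        # Reduce: average the per-tree probabilities, as the base method does
        return np.mean(all_proba, axis=0)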
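A usage sketch of this approach, repeating the wrapper so it runs on its own. It keeps a copy of the per-tree probabilities on the instance under a hypothetical attribute name, `last_all_proba_` (not part of scikit-learn), and checks the averaged result against the unmodified base method. It assumes a single-output classifier:

```python
import numpy as np
from joblib import Parallel, delayed
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted


class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        # Single-output forests only (n_outputs_ == 1)
        check_is_fitted(self)
        X = check_array(X, dtype=np.float32, accept_sparse="csr")
        all_proba = Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(tree.predict_proba)(X) for tree in self.estimators_)
        # Keep the intermediate result; attribute name is hypothetical
        self.last_all_proba_ = np.stack(all_proba)  # (n_trees, n_samples, n_classes)
        return self.last_all_proba_.mean(axis=0)


X, y = load_iris(return_X_y=True)
clf = RandomForestClassifierWrapper(n_estimators=25, random_state=0).fit(X, y)

proba = clf.predict_proba(X[:5])
print(clf.last_all_proba_.shape)  # (25, 5, 3)

# The averaged result should match the unmodified base class method
base = RandomForestClassifier.predict_proba(clf, X[:5])
print(np.allclose(proba, base))  # True
```

Storing state during prediction is not the usual scikit-learn convention and is not thread-safe if the same instance is used concurrently; returning the per-tree array from a separate helper method is an alternative.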

Answer by hajtos:

There is no practical way to access a method's local variables from the outside. Since you have the source code of the base classifier, what you can do is override the predict_proba method by copying the code from the base class and then handling the local variables however you want.
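A dependency-free toy illustration of copying a method body to keep a local variable. `Base`, `Wrapper`, and `last_scores` are hypothetical names standing in for the scikit-learn class and the variable you want to keep:

```python
class Base:
    def predict_proba(self, x):
        scores = [x * w for w in (1.0, 2.0, 3.0)]  # local, discarded on return
        total = sum(scores)
        return [s / total for s in scores]


class Wrapper(Base):
    def predict_proba(self, x):
        # Body copied from Base.predict_proba, plus one line that keeps
        # the intermediate local around after the call returns
        scores = [x * w for w in (1.0, 2.0, 3.0)]
        self.last_scores = list(scores)  # hypothetical attribute name
        total = sum(scores)
        return [s / total for s in scores]


w = Wrapper()
out = w.predict_proba(2.0)
print(w.last_scores)  # [2.0, 4.0, 6.0]

# Guard against the copied body drifting from the base class,
# e.g. after upgrading the library you copied it from
assert out == Base().predict_proba(2.0)
```

The trade-off is maintenance: the copy does not pick up later changes to the base method, so a check comparing both outputs is worth keeping in your tests.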