I'm working on a project that involves implementing some algorithms as Python classes and testing their performance. I decided to write them up as sklearn estimators so that I could use GridSearchCV for validation.
However, one of my algorithms, for Inductive Matrix Completion (IMC), takes more than just X and y as arguments. This is a problem for GridSearchCV.fit, as there appears to be no way to pass more than X and y to the estimator's fit method. The source shows the following arguments for GridSearchCV.fit:
def fit(self, X, y=None, groups=None, **fit_params):
And of course the downstream methods expect only these two arguments. Modifying my local copy of GridSearchCV to accommodate my needs would be neither trivial nor advisable.
For reference IMC basically states that $ R \approx XW^THY^T $. So my fit method takes the following form:
def fit(self, R, X, Y):
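To make the shapes in that factorization concrete, here is a small NumPy sketch. The dimension names `n`, `m`, `p`, `q`, and `k`, the sizes, and the random data are all assumptions for illustration; the post itself only gives the product $XW^THY^T$.

```python
import numpy as np

rng = np.random.default_rng(0)

n, m = 6, 4   # rows and columns of R (hypothetical sizes)
p, q = 3, 2   # feature counts for X and Y (hypothetical)
k = 2         # latent rank shared by W and H (hypothetical)

X = rng.normal(size=(n, p))   # one feature row per row of R
Y = rng.normal(size=(m, q))   # one feature row per column of R
W = rng.normal(size=(k, p))
H = rng.normal(size=(k, q))

# R is approximated by X W^T H Y^T, so the product must be n x m.
R_hat = X @ W.T @ H @ Y.T
print(R_hat.shape)
```

Note that X has one row per row of R while Y has one row per column of R, so the two feature matrices generally have different lengths. That is why they cannot simply be concatenated into one sklearn-style feature matrix.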
So trying the following fails as the Y value never gets passed to the IMC.fit method:
imc = IMC()
params = {...}
gs = GridSearchCV(imc, param_grid=params)
gs.fit(R, X, Y)
I've created a workaround for this by modifying the IMC.fit method like so (this also has to be inserted into the score method):
def fit(self, R, X, Y=None):
    if Y is None:
        split = np.where(np.all(X == 999, axis=0))[0][0]
        Y = X[:, split + 1:]
        X = X[:, :split]
    ...
This allows me to use numpy.hstack to stack X and Y horizontally and insert a column of all 999 between them. This array can then be passed to GridSearchCV.fit as follows:
data = np.hstack([X, np.ones((X.shape[0],1)) * 999, Y])
gs.fit(R, data)
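The packing and unpacking steps can be checked in isolation. The sketch below is a round trip of the same trick on toy arrays; the helper names `pack` and `unpack` are hypothetical and not part of the IMC class, and it assumes X and Y already have the same number of rows, which `np.hstack` requires.

```python
import numpy as np

SENTINEL = 999  # same marker value as in the workaround above


def pack(X, Y):
    """Stack X and Y side by side with a column of SENTINEL between them."""
    divider = np.full((X.shape[0], 1), SENTINEL)
    return np.hstack([X, divider, Y])


def unpack(data):
    """Split a packed array back into X and Y at the first all-SENTINEL column."""
    split = np.where(np.all(data == SENTINEL, axis=0))[0][0]
    return data[:, :split], data[:, split + 1:]


X = np.arange(12, dtype=float).reshape(4, 3)
Y = np.arange(8, dtype=float).reshape(4, 2) + 100

data = pack(X, Y)
X_back, Y_back = unpack(data)

# A row subset can still be unpacked, which is what a CV split hands to fit.
X_fold, Y_fold = unpack(data[[0, 2]])
```

One weakness of this encoding: if a real feature column of X happens to consist entirely of 999, `unpack` will split at that column instead of the divider.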
This approach works, but feels pretty hacky. Therefore my question is this:
So after getting some inspiration from a friend on this (@Matthew Drury) I constructed a much more elegant solution.
Again the problem is framed as such:
I have a matrix completion method that takes `X`, `Y`, and `R` as arguments and attempts to construct `W` and `H` that minimize $R - XW^THY^T$ over all observed indices in `R`. A basic `fit` method for this has the signature `fit(self, R, X, Y)`.

This doesn't fit well into the standard sklearn model, where `fit` takes an `X` (the features that feed into a model) and a `y` (the results), with the signature `fit(self, X, y)`.

This isn't really an issue until you start using `GridSearchCV` or other cross-validation methods, as they expect the data to fit the latter format. So to marry these two concepts I needed a way of packaging two disparate matrices, `X` and `Y`, into a single structure without losing the separate nature of the two.

In the 5 minutes I had to dedicate to this originally, I came up with the hacky solution. In a matrix `R` of shape `n, m`, where the rows correspond to the records in `X` and the columns correspond to the records in `Y`, there are `b` total observed entries. If we take the row and column indices for all of these entries and index `X` on the rows and `Y` on the columns, we end up with equal-length matrices for `X` and `Y`. These can then be stacked horizontally, separated by a column of nonsense, and passed to the cross-validation methods without issue (we just need a couple of helper methods inside the original class to reconstruct the original `X` and `Y` from the stack before fitting).

The point of this question was to find the elegant solution, or preferably an existing solution. No existing solution seems to be available, so I will propose the following model for any future estimators/classifiers that inherit from sklearn and require more than a single feature matrix for the `fit` method.
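The row and column indexing described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: unobserved entries of R are marked with NaN here (the post does not say how they are represented), and the names `X_long`, `Y_long`, and `r` are hypothetical.

```python
import numpy as np

# Unobserved entries are NaN in this sketch (an assumption).
R = np.array([[5.0, np.nan, 1.0],
              [np.nan, 3.0, np.nan]])      # shape (n, m) = (2, 3)
X = np.array([[1.0, 0.0], [0.0, 1.0]])     # n rows of row features
Y = np.array([[1.0], [2.0], [3.0]])        # m rows of column features

rows, cols = np.nonzero(~np.isnan(R))      # indices of the b observed entries
X_long = X[rows]                           # shape (b, p): X row for each entry
Y_long = Y[cols]                           # shape (b, q): Y row for each entry
r = R[rows, cols]                          # the b observed values

print(X_long.shape, Y_long.shape, r.shape)
```

After this step each observed entry of R is one "sample": `X_long` and `Y_long` have the same number of rows as `r`, so a cross-validation splitter can select matching rows from all three.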
Create a DataHandler
When using `GridSearchCV`, the `fit` method does a round of checks before firing off any calls to the estimator's `fit` method. One of these determines whether the passed `X` array is indexable. This test basically checks that `X` implements `__getitem__` or `iloc` and is the same length as `y`. This length check requires `X` to have a `shape` attribute. At that point the split indices and fits can be computed as expected. So we need a wrapper that implements `__getitem__` and has a `shape` attribute.

That's it! We can now modify the `fit` method to match the sklearn style, but in this case, instead of `X` being an array, it will be either a tuple (the result returned by the `__getitem__` method) or an instance of our `DataHandler` class.

Now `GridSearchCV` will work as expected when passed an instance of a `DataHandler` containing the `X` and `Y` arrays.
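The original class listing did not survive in this copy of the post, so the following is only a minimal sketch of what such a wrapper might look like, built from the description above: a `shape` attribute, a `__getitem__` that returns a tuple, and a helper that `fit` and `score` could call to accept either form. The helper name `split_features` is hypothetical, the tuple handling in `__getitem__` is a precaution rather than documented behaviour, and the sketch has not been checked against any particular scikit-learn version.

```python
import numpy as np


class DataHandler:
    """Hypothetical wrapper pairing row-aligned X and Y arrays."""

    def __init__(self, X, Y):
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X and Y must have the same number of rows")
        self.X = X
        self.Y = Y
        # Only the first entry matters for a length check against y.
        self.shape = (X.shape[0], X.shape[1] + Y.shape[1])

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, index):
        # Precaution (assumption): accept an index such as (rows, Ellipsis)
        # and keep only the row part.
        if isinstance(index, tuple):
            index = index[0]
        return self.X[index], self.Y[index]


def split_features(data):
    """Return (X, Y) from either a DataHandler or the tuple it yields."""
    if isinstance(data, DataHandler):
        return data.X, data.Y
    X, Y = data
    return X, Y


X = np.arange(10, dtype=float).reshape(5, 2)
Y = np.arange(15, dtype=float).reshape(5, 3)
handler = DataHandler(X, Y)

train = np.array([0, 1, 3])   # stand-in for indices from a CV splitter
X_train, Y_train = split_features(handler[train])
X_all, Y_all = split_features(handler)
```

With something like this in place, an estimator's `fit(self, data, r)` could start with `X, Y = split_features(data)` and then proceed exactly as the three-argument version did.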