I'm working on a project that involves implementing some algorithms as Python classes and testing their performance. I decided to write them up as sklearn estimators so that I could use `GridSearchCV` for validation.
However, one of my algorithms, for Inductive Matrix Completion (IMC), takes more than just `X` and `y` as arguments. This becomes a problem for `GridSearchCV.fit`, as there appears to be no way to pass more than just `X` and `y` to the `fit` method of the estimator. The source shows the following signature for `GridSearchCV.fit`:
def fit(self, X, y=None, groups=None, **fit_params):
And of course the downstream methods expect only these two arguments. Obviously it would be neither trivial nor advisable to modify my local copy of `GridSearchCV` to accommodate my needs.
For reference IMC basically states that $ R \approx XW^THY^T $. So my fit method takes the following form:
def fit(self, R, X, Y):
So trying the following fails, as the `Y` value never gets passed to the `IMC.fit` method:
imc = IMC()
params = {...}
gs = GridSearchCV(imc, param_grid=params)
gs.fit(R, X, Y)
I've created a workaround for this by modifying the `IMC.fit` method like so (this also has to be inserted into the `score` method):
def fit(self, R, X, Y=None):
    if Y is None:
        split = np.where(np.all(X == 999, axis=0))[0][0]
        Y = X[:, split + 1:]
        X = X[:, :split]
    ...
This allows me to use `numpy.hstack` to stack `X` and `Y` horizontally and insert a column of all `999` between them. This array can then be passed to `GridSearchCV.fit` as follows:
data = np.hstack([X, np.ones((X.shape[0],1)) * 999, Y])
gs.fit(R, data)
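To make the round trip explicit, here is a minimal sketch of the packing and unpacking steps on their own. The helper names `pack` and `unpack` are hypothetical (they are not part of the estimator above), and the sketch assumes `X` and `Y` have the same number of rows and that no real column consists entirely of the sentinel value.

```python
import numpy as np

SENTINEL = 999  # separator value used in the workaround above


def pack(X, Y, sentinel=SENTINEL):
    """Stack X and Y side by side with one all-sentinel column between them."""
    sep = np.full((X.shape[0], 1), sentinel)
    return np.hstack([X, sep, Y])


def unpack(data, sentinel=SENTINEL):
    """Split a packed array back into X and Y at the first all-sentinel column."""
    split = np.where(np.all(data == sentinel, axis=0))[0][0]
    return data[:, :split], data[:, split + 1:]


# Toy data: 3 rows, 2 features in X and 3 features in Y.
X = np.arange(6, dtype=float).reshape(3, 2)
Y = np.arange(9, dtype=float).reshape(3, 3)

packed = pack(X, Y)
X_back, Y_back = unpack(packed)
```

Because the split is found by value, a genuine feature column that happens to be all `999` would be mistaken for the separator, which is one reason the approach feels fragile.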
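This approach works, but feels pretty hacky. Therefore my question is this: is there a more elegant, or preferably an existing, way to pass more than just `X` and `y` through `GridSearchCV` to the estimator's `fit` method?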
So after getting some inspiration from a friend on this (@Matthew Drury) I constructed a much more elegant solution.
Again the problem is framed as such:
I have a matrix completion method that takes `X`, `Y`, and `R` as arguments and attempts to construct `W` and `H` that minimize $ R - XW^THY^T $ for all observed indices in `R`. A basic `fit` method for this would have the signature `fit(self, R, X, Y)`.

This doesn't fit well into the standard sklearn model, where `fit` takes an `X` (the features that feed into a model) and a `y` (the results) and has the signature `fit(self, X, y)`.

This isn't really an issue until you start using `GridSearchCV` or other cross-validation methods, as they expect the data to fit the latter format. So to marry these two concepts I needed a way of packaging two disparate matrices `X` and `Y` into a single structure without losing the separate nature of the two.

In the 5 minutes I originally had to dedicate to this, I came up with the hacky solution. In a matrix `R` of shape `(n, m)`, where the rows correspond to the records in `X` and the columns correspond to the records in `Y`, there are `b` observed entries in total. If we take the row and column indices of all of these entries and index `X` on the rows and `Y` on the columns, we end up with equal-length matrices for `X` and `Y`. These can then be stacked horizontally, separated by a column of nonsense, and passed to the cross-validation methods without issue (we just need a couple of helper methods inside the original class to reconstruct the original `X` and `Y` from the stack before fitting).

The point of this question was to find the elegant solution, or preferably an existing solution. No existing solution seems to be available, so I will propose the following model for any future estimators/classifiers that inherit from sklearn and require more than a single feature matrix for the `fit` method.
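The index expansion described above (one row of `X` and one row of `Y` per observed entry of `R`) can be sketched as follows. The function name `expand_observed` is hypothetical, and treating zeros in `R` as unobserved is an assumption made only for this toy example.

```python
import numpy as np


def expand_observed(R, X, Y, observed):
    """Return row-aligned copies of X and Y plus the matching entries of R.

    `observed` is a boolean mask with the same shape as R.
    """
    rows, cols = np.nonzero(observed)
    return X[rows], Y[cols], R[rows, cols]


# R has shape (n, m) = (2, 3); here zeros stand in for unobserved entries.
R = np.array([[5.0, 0.0, 3.0],
              [0.0, 4.0, 0.0]])
observed = R != 0

X = np.array([[1.0, 0.0],
              [0.0, 1.0]])          # n = 2 records
Y = np.array([[1.0], [2.0], [3.0]])  # m = 3 records

X_b, Y_b, r = expand_observed(R, X, Y, observed)
# X_b and Y_b now both have b = 3 rows, one per observed entry in r.
```

After this step every sample is a triple (row of `X_b`, row of `Y_b`, entry of `r`), which is the shape of data that cross-validation splitters can slice by index.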
Create a DataHandler
When using `GridSearchCV`, the `fit` method does a round of checks before firing off any calls to the estimator's `fit` method. One of these determines whether the passed `X` array is indexable. This test basically checks that `X` implements `__getitem__` or `iloc` and is the same length as `y`. This length check requires `X` to have a `shape` attribute. At that point the split indices and fits can be computed as expected. So we need a wrapper that implements `__getitem__` and has a `shape` attribute.

That's it! We can now modify the `fit` method to match the sklearn style, but in this case, instead of `X` being an array, it will be either a tuple (the result returned by the `__getitem__` method) or an instance of our `DataHandler` class.

Now `GridSearchCV` will work as expected when passed an instance of a `DataHandler` containing the `X` and `Y` arrays.
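Here is a minimal sketch of how these pieces could fit together. Only the `DataHandler` name and its two requirements (`__getitem__` and `shape`) come from the description above; the class body is an assumption. `BilinearRidge` is a hypothetical stand-in estimator, not IMC: it fits a ridge-regularized bilinear model just so the example has something to tune, and `alpha` is its made-up hyperparameter. The sketch also relies on sklearn slicing a custom object through `__getitem__`, which is internal behavior and may differ between versions.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV


class DataHandler:
    """Bundle two row-aligned matrices so they can travel as a single `X`."""

    def __init__(self, X, Y):
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X and Y must have the same number of rows")
        self.X = X
        self.Y = Y
        self.shape = (X.shape[0],)  # sklearn reads shape[0] as the sample count

    def __getitem__(self, key):
        # sklearn may pass an index array or (index array, Ellipsis);
        # NumPy indexing accepts both forms.
        return self.X[key], self.Y[key]


class BilinearRidge(BaseEstimator):
    """Stand-in estimator (not IMC): ridge fit of r_i ~ x_i^T M y_i."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    @staticmethod
    def _unpack(data):
        # CV folds arrive as tuples; the final refit receives the DataHandler.
        if isinstance(data, DataHandler):
            return data.X, data.Y
        X, Y = data
        return X, Y

    @staticmethod
    def _features(X, Y):
        # Row-wise outer products x_i y_i^T, flattened into feature vectors.
        return np.einsum("ij,ik->ijk", X, Y).reshape(X.shape[0], -1)

    def fit(self, data, r):
        F = self._features(*self._unpack(data))
        A = F.T @ F + self.alpha * np.eye(F.shape[1])
        self.coef_ = np.linalg.solve(A, F.T @ r)
        return self

    def predict(self, data):
        return self._features(*self._unpack(data)) @ self.coef_

    def score(self, data, r):
        # Negative mean squared error, so that larger is better.
        return -float(np.mean((self.predict(data) - r) ** 2))


# Synthetic data: R = X M Y^T with roughly 80% of entries observed.
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
Y = rng.normal(size=(6, 2))
R = X @ rng.normal(size=(3, 2)) @ Y.T
rows, cols = np.nonzero(rng.random(R.shape) < 0.8)

# One row of X and one row of Y per observed entry of R.
data = DataHandler(X[rows], Y[cols])
r = R[rows, cols]

gs = GridSearchCV(BilinearRidge(), param_grid={"alpha": [0.01, 1.0, 100.0]}, cv=3)
gs.fit(data, r)
```

Returning a tuple from `__getitem__` keeps the wrapper tiny, at the cost of every estimator method having to accept both a tuple and a `DataHandler`; funnelling that through a single `_unpack` helper keeps `fit`, `predict`, and `score` in sync.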