I'm trying to overfit a network on simple data. The data I'm working with is the MNIST image dataset: 60,000 training images of 784 pixels (28×28) each.
What I want to do is a form of phase retrieval. I took the MNIST dataset and performed a 2D Fourier transform on each image. This turned the 60000 × 784 real matrix into a 60000 × 784 complex matrix.
Then I took the absolute value of each entry and put it in a new 60000 × 784 real matrix called amplitudes. I also took the angle (phase) of each entry and put it in a 60000 × 784 real matrix called phases.
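The per-image transform and the amplitude/phase split can also be done in one vectorized call, because `np.fft.fft2` transforms the last two axes of an array by default. The following is a minimal sketch that uses a small random array as a hypothetical stand-in for the real MNIST matrix. It also checks that amplitudes and phases together still determine each image exactly:

```python
import numpy as np

# Hypothetical stand-in for the 60000 x 784 MNIST matrix: 10 random 28x28 images.
rng = np.random.default_rng(0)
images = rng.random((10, 784))

# Reshape to (N, 28, 28); np.fft.fft2 transforms the last two axes by default.
spectra = np.fft.fft2(images.reshape(-1, 28, 28)).reshape(-1, 784)

amplitudes = np.abs(spectra)  # real, shape (N, 784)
phases = np.angle(spectra)    # real, in [-pi, pi], shape (N, 784)

# Amplitudes and phases together give back the complex spectrum, hence the image.
rebuilt = amplitudes * np.exp(1j * phases)
recovered = np.fft.ifft2(rebuilt.reshape(-1, 28, 28)).real.reshape(-1, 784)
print(np.allclose(recovered, images))  # True
```

On the real data, `images` would be replaced by the loaded MNIST matrix. This removes the Python-level loop over 60,000 images.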
The goal is to predict the phases given the amplitudes.
This is the (extremely simple) code:
```python
from keras.models import Sequential
from keras.layers import Dense
import numpy as np

def normalize_angles(phases):
    # Map angles from [-pi, pi] to [0, 1] to match the sigmoid output range
    phases = phases + np.pi
    phases /= (2 * np.pi)
    return phases

def build_fourier_mnist():
    mnist = np.load("train_features.npy")  # MNIST as is, shape (60000, 784)
    fourier_mnist = np.zeros(mnist.shape, dtype=np.complex128)  # np.complex was removed in NumPy 1.24
    for i in range(mnist.shape[0]):
        current_image = np.reshape(mnist[i, :], (28, 28))  # turn into a matrix so we can perform a 2D FFT
        fourier_current_image = np.fft.fft2(current_image)  # perform 2D FFT
        fourier_mnist[i, :] = np.reshape(fourier_current_image, 784)  # flatten and save to new matrix
    return fourier_mnist

fourier_mnist = build_fourier_mnist()
amplitudes = np.abs(fourier_mnist)
phases = normalize_angles(np.angle(fourier_mnist))

model = Sequential()
model.add(Dense(784, input_dim=amplitudes.shape[1], activation='sigmoid'))
model.add(Dense(784, activation='sigmoid'))
model.add(Dense(784, activation='sigmoid'))
model.add(Dense(784, activation='sigmoid'))
model.add(Dense(phases.shape[1], activation='sigmoid'))

# Compile model
model.compile(loss='mean_squared_error', optimizer='adam')

# Fit the model
model.fit(amplitudes, phases, epochs=400, batch_size=100)
model.save("phase_retriever2.h5")
```
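`normalize_angles` maps the output of `np.angle`, which lies in [-π, π], onto [0, 1]. This matches the range of the final sigmoid layer. To judge predictions visually, the network's output has to be mapped back to angles and combined with the known amplitudes. The following is a minimal sketch of that step. `denormalize_angles` and `reconstruct_images` are hypothetical helper names, not part of Keras or NumPy:

```python
import numpy as np

def normalize_angles(phases):
    # Same mapping as in the question: [-pi, pi] -> [0, 1]
    return (phases + np.pi) / (2 * np.pi)

def denormalize_angles(normalized):
    # Hypothetical inverse of normalize_angles: [0, 1] -> [-pi, pi]
    return normalized * 2 * np.pi - np.pi

def reconstruct_images(amplitudes, normalized_phases):
    # Hypothetical helper: rebuild the complex spectra from amplitudes and
    # (possibly predicted) normalized phases, then invert the 2D FFT.
    spectra = amplitudes * np.exp(1j * denormalize_angles(normalized_phases))
    images = np.fft.ifft2(spectra.reshape(-1, 28, 28))
    # The imaginary part is discarded; it is ~0 only when the phases respect
    # the conjugate symmetry of a real image's spectrum.
    return images.real.reshape(-1, 784)

# Usage with random stand-in images (not MNIST): the true phases give back the input.
rng = np.random.default_rng(1)
x = rng.random((5, 784))
spectra = np.fft.fft2(x.reshape(-1, 28, 28)).reshape(-1, 784)
amps = np.abs(spectra)
norm_phases = normalize_angles(np.angle(spectra))
recovered = reconstruct_images(amps, norm_phases)
print(np.allclose(recovered, x))  # True
```

With a trained model, `norm_phases` would instead be `model.predict(amplitudes)`. Comparing the reconstructed images with the originals gives a more interpretable check than the raw loss value.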
It does "work" in the sense that it gives a reasonably OK result, but I'm trying to massively overfit it. The best training error (mean squared error on the normalized phases) I got was around 0.083. Can anyone get it lower? How can I improve this?
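Consider a model that ignores its input and always predicts each output's mean. If the normalized phases were uniformly distributed on [0, 1], that model's expected MSE would equal the variance of the uniform distribution, 1/12 ≈ 0.0833. This is an assumption: MNIST phases are not guaranteed to be uniform, so the baseline should be measured on the actual data. The following is a quick sketch of that measurement. It uses random images as a hypothetical stand-in for MNIST, so the number it prints is only indicative:

```python
import numpy as np

def normalize_angles(phases):
    # [-pi, pi] -> [0, 1], as in the question
    return (phases + np.pi) / (2 * np.pi)

rng = np.random.default_rng(0)
images = rng.random((2000, 784))  # hypothetical stand-in; load the real MNIST matrix instead
spectra = np.fft.fft2(images.reshape(-1, 28, 28)).reshape(-1, 784)
targets = normalize_angles(np.angle(spectra))

# Baseline "model": predict each output's training-set mean, ignoring the amplitudes.
baseline_pred = targets.mean(axis=0, keepdims=True)
baseline_mse = np.mean((targets - baseline_pred) ** 2)
print(baseline_mse, 1 / 12)
```

If the baseline computed on the real data comes out close to the reported training error, the network may have learned little beyond the per-pixel mean phase.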