I'm trying to overfit a network on simple data. The data I'm working with is the MNIST image dataset: 60000 training images of 784 pixels (28 by 28) each.

What I want to do is a form of phase retrieval. I took the MNIST dataset and applied a 2D Fourier transform to each image. This turned the 60000 by 784 real matrix into a 60000 by 784 complex matrix.

Then I took the absolute value of each entry and put it in a new 60000 by 784 real matrix called amplitudes. I also took the angle (or phase) of each entry and put it in a 60000 by 784 real matrix called phases.
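
To make the amplitude/phase split concrete, here is a minimal sketch on a single image. It uses random data as a hypothetical stand-in for one MNIST image (not the real dataset) and checks that amplitude and phase together reproduce the original image:

```python
import numpy as np

# Hypothetical stand-in for one 28x28 MNIST image (random data, an assumption).
rng = np.random.default_rng(0)
image = rng.random((28, 28))

spectrum = np.fft.fft2(image)   # complex 28x28 array
amplitude = np.abs(spectrum)    # non-negative magnitudes
phase = np.angle(spectrum)      # angles in radians, within [-pi, pi]

# Amplitude and phase together recover the spectrum, and hence the image.
rebuilt_spectrum = amplitude * np.exp(1j * phase)
rebuilt_image = np.fft.ifft2(rebuilt_spectrum).real

print(np.allclose(rebuilt_image, image))
```

Dropping the phase loses this round trip, which is why the phases have to be predicted from the amplitudes.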

The goal is to predict the phases given the amplitudes.

This is the extremely simple code:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np

def normalize_angles(phases):
    phases = phases + np.pi
    phases /= (2 * np.pi)
    return phases

def build_fourier_mnist():
    mnist = np.load("train_features.npy") #MNIST as is.
    fourier_mnist = np.zeros(mnist.shape, dtype=np.complex128) #np.complex was removed in NumPy 1.24
    for i in range(mnist.shape[0]):
        current_image = np.reshape(mnist[i, :], (28, 28)) #Turn to matrix so we can perform 2d fft
        fourier_current_image = np.fft.fft2(current_image) #perform 2d fft
        fourier_mnist[i, :] = np.reshape(fourier_current_image,(1, 784)) #flatten and save to new matrix
    return fourier_mnist

fourier_mnist = build_fourier_mnist()
amplitudes = np.abs(fourier_mnist)
phases = normalize_angles(np.angle(fourier_mnist))

model = Sequential()
model.add(Dense(784, input_dim=amplitudes.shape[1], activation='sigmoid'))
model.add(Dense(784, activation='sigmoid'))
model.add(Dense(784, activation='sigmoid'))
model.add(Dense(784, activation='sigmoid'))
model.add(Dense(phases.shape[1], activation='sigmoid'))
#Compile model
model.compile(loss='mean_squared_error', optimizer='adam')
#Fit the model
model.fit(amplitudes, phases, epochs=400, batch_size=100)
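
As a side note on the preprocessing above, the per-image loop can probably be replaced by one batched FFT call. This is only a sketch: `build_fourier_batch` and `fake_mnist` are hypothetical names, and random data stands in for `train_features.npy`.

```python
import numpy as np

def build_fourier_batch(images_flat):
    """2D FFT of each flattened 28x28 image; returns an (n, 784) complex array."""
    n = images_flat.shape[0]
    images = images_flat.reshape(n, 28, 28)
    spectra = np.fft.fft2(images, axes=(-2, -1))  # FFT over the last two axes only
    return spectra.reshape(n, 784)

def normalize_angles(phases):
    """Map angles from [-pi, pi] to [0, 1], as in the code above."""
    return (phases + np.pi) / (2 * np.pi)

# Hypothetical random data standing in for train_features.npy (an assumption).
rng = np.random.default_rng(0)
fake_mnist = rng.random((5, 784))

fourier = build_fourier_batch(fake_mnist)
amplitudes = np.abs(fourier)
phases = normalize_angles(np.angle(fourier))

print(amplitudes.shape, phases.shape)
```

The result should match the loop version entry for entry, since `np.fft.fft2` transforms only the last two axes and treats the first axis as a batch dimension.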

It does "work" in the sense that it gives a reasonable result, but I'm deliberately trying to massively overfit the training set. The best training loss (mean squared error) I got was around 0.083. Can this be pushed any lower? How can I improve this?
