Recently, I started trying to train a neural network using backpropagation. The network structure is 784-512-10, and I used the Sigmoid activation function. When I tested a single-layer network on the MNIST dataset, I got around 90%. My results are around 86% with this multi-layer network, is this normal? Did I get the backpropagation part wrong?
Here is my code:
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.Scanner;
public class NeuralNetwork{
public static double learningRate = 0.01;
public static int epoch = 15;
public static int ROWS = 28;
public static int COLUMNS = 28;
public static int INPUT = ROWS * COLUMNS;
public static int outNum = 10;
public static int hiddenNum = 512;
public static double[][] weights2 = new double[outNum][hiddenNum];
public static double[] bias2 = new double[outNum];
public static double[][] weights1 = new double[hiddenNum][INPUT];
public static double[] bias1 = new double[outNum];
private static final double TRAININGSIZE = 10;
public static double[][] inputs = new double[outNum][INPUT];
private static final double[][] target = new double[outNum][outNum];
private static final ArrayList<String> filenames = new ArrayList<>();
private static final ArrayList<Integer> yetDone = new ArrayList<>();
public static double[] actual = new double[outNum];
public static Random rand = new SecureRandom();
public static Scanner input = new Scanner(System.in);
public static void main(String[]args) throws Exception {
System.out.println("1. Learn the network");
System.out.println("2. Guess a number");
System.out.println("3. Guess file");
System.out.println("4. Guess All Numbers");
System.out.println("5. Guess image");
switch (input.nextInt()){
case 1:
learn();
break;
case 2:
guess();
break;
case 3:
guessFile();
break;
case 4:
guessAll();
break;
}
}
public static void guessAll() throws IOException, ClassNotFoundException {
System.out.println("Recognizing...");
/*
for(int x = 1; x < 60000; x++){
filenames.add("data/" + String.format("%05d",x) + ".txt");
}
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
Layers lay = (Layers) ois.readObject();
int correct = 0;
for (String z : filenames) {
double[] a = scan(z,0);
correct += getBestGuess(sigmoid(lay.step(a))) == actual[0] ? 1 : 0;
}
System.out.println("Training: " + correct + " / " + filenames.size() + " correct.");
filenames.clear();
*/
for(int x = 60000; x < 70000; x++){
filenames.add("data/" + String.format("%05d",x) + ".txt");
}
ObjectInputStream oiss = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network1.ser")));
Layers lays1 = (Layers) oiss.readObject();
ObjectInputStream oiss2 = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network2.ser")));
Layers lays2 = (Layers) oiss2.readObject();
int corrects = 0;
for (String z : filenames) {
double[] a = scan(z,0);
corrects += getBestGuess(sigmoid(lays2.step(sigmoid(lays1.step(a))))) == actual[0] ? 1 : 0;
}
System.out.println("Testing: " + corrects + " / " + filenames.size() + " correct.");
System.out.println("Done!");
}
public static void makeList(){
for(int index = 0; index < TRAININGSIZE; index++){
int indices = rand.nextInt(yetDone.size() - 1) + 1;
filenames.add("data/" + String.format("%05d",yetDone.get(indices)) + ".txt");
yetDone.remove(indices);
}
prepareData();
for(int indices = 0; indices < outNum; indices++) {
for(int index = 0; index < outNum; index++){
target[indices][index] = 0;
}
target[indices][(int)actual[indices]] = 1;
}
}
public static void prepareData(){
for(int index = 0; index < outNum; index++){
try {
inputs[index] = scan(filenames.get(index), index);
} catch (FileNotFoundException ex) {
ex.printStackTrace();
}
}
}
public static double[] scan(String filename, int index) throws FileNotFoundException {
Scanner in = new Scanner(new File(filename));
double[] a = new double[INPUT];
for(int i = 0; i < INPUT; i++){
a[i] = in.nextDouble() / 255;
}
actual[index] = in.nextDouble();
return a;
}
public static void guessFile() throws IOException, ClassNotFoundException {
System.out.print("Enter Filename: ");
double[] a = scan(input.next(), 0);
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
Layers lay = (Layers) ois.readObject();
double[] results = lay.step(a);
System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
System.out.println(Arrays.toString(results));
}
public static double guess(double[] a) throws IOException, ClassNotFoundException {
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
Layers lay = (Layers) ois.readObject();
double[] results = lay.step(a);
return getBestGuess(sigmoid(results));
}
public static void guess() throws IOException, ClassNotFoundException {
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
System.out.println("Input number: ");
Layers lay = (Layers) ois.readObject();
double[] a = new double[INPUT];
for(int index = 0; index < a.length; index++){
a[index] = input.nextInt();
}
double[] results = lay.step(a);
System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
System.out.println(Arrays.toString(sigmoid(results)));
}
public static void learn() {
System.out.println("Learning...");
initialise(weights2, outNum, hiddenNum);
initialise(bias2);
initialise(weights1,hiddenNum, INPUT);
initialise(bias1);
Layers lay2 = new Layers(weights2, bias2, outNum, hiddenNum);
Layers lay1 = new Layers(weights1, bias1, hiddenNum, INPUT);
double[] result2 = new double[lay2.outNum];
double[] result1 = new double[lay1.outNum];
double[] a2;
double[] a1;
double cost = 0;
double sumFinal;
for(int x = 0; x < epoch; x++) {
yetDone.clear();
for(int y = 0; y < 60000; y++){
yetDone.add(y);
}
for (int ind = 0; ind < 200; ind++) {
filenames.clear();
makeList();
for (int n = 0; n < lay2.outNum; n++) {
a1 = inputs[n]; //number
result1 = sigmoid(lay1.step(a1));
a2 = result1;
result2 = sigmoid(lay2.step(a2));
for (int i = 0; i < lay2.outNum; i++) {
for (int j = 0; j < lay2.INPUT; j++) {
weights2[i][j] += learningRate * a2[j] * (target[n][i] - result2[i]);
cost += Math.pow((target[n][i] - result2[i]), 2);
}
}
for(int i = 0; i < lay1.outNum; i++){
for(int j = 0; j < lay1.INPUT; j++){
sumFinal = 0;
for(int k = 0; k < lay2.outNum; k++){
// weight * derivSigma(outputHiddenLayer) * 2(out - expected)
sumFinal += result1[k] * (1 - result1[k]) * 2 * (result2[k] - target[n][k]); // * weights2[k][i]
}
weights1[i][j] -= learningRate * a1[j] * sumFinal * result1[i] * (1 - result1[i]);
}
}
}
lay1.update(weights1, bias1);
lay2.update(weights2, bias2);
}
System.out.println("Epoch " + x + ": " + cost);
cost = 0;
}
System.out.println(Arrays.toString(result1));
System.out.println(Arrays.toString(result2));
for(double[] arr : inputs) {
System.out.println("This is a " + getBestGuess(sigmoid(lay2.step(sigmoid(lay1.step(arr))))) + "!");
}
try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network1.ser")))) {
oos.writeObject(lay1);
} catch (IOException ex) {
ex.printStackTrace();
}
try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network2.ser")))) {
oos.writeObject(lay2);
} catch (IOException ex) {
ex.printStackTrace();
}
System.out.println("Done! Saved to file.");
}
public static double sigmoid(double x){
return 1 / (1 + Math.exp(-x));
}
public static double[] sigmoid(double[] weights){
for(int index = 0; index < weights.length; index++){
weights[index] = sigmoid(weights[index]);
}
return weights;
}
public static void initialise(double[] bias){
Random random = new Random();
for(int index = 0; index < bias.length; index++){
bias[index] = random.nextGaussian();
}
}
public static void initialise(double[][] weights, int outNum, int INPUT){
Random random = new Random();
for(int index = 0; index < outNum; index++){
for(int indice = 0; indice < INPUT; indice++){
weights[index][indice] = random.nextGaussian();
}
}
}
public static int getBestGuess(double[] result){
double k = Integer.MIN_VALUE;
double index = 0;
int current = 0;
for(double a : result){
if(k < a){
k = a;
index = current;
}
current++;
}
return (int)index;
}
}
class Layers implements Serializable {
private static final long serialVersionUID = 8L;
double[][] weights;
double[] bias;
int outNum;
int INPUT;
public Layers(double[][] weights, double[] bias, int outNum, int INPUT){
this.weights = weights;
this.bias = bias;
this.outNum = outNum;
this.INPUT = INPUT;
}
public void update(double[][] weights, double[] bias){
this.weights = weights;
this.bias = bias;
}
public double[] step(double[] aa){
double[] out = new double[outNum];
for (int index = 0; index < outNum; index++) {
for (int indices = 0; indices < INPUT; indices++) {
out[index] += weights[index][indices] * aa[indices];
}
}
return out;
}
}
Thanks in advance!