
I have recently been trying to get a better grip on machine learning from an implementation point of view rather than a statistical one. I've read several explanations of how to implement a neural network in pseudocode, and this is the result: a toy neural network. I have used several sources from medium.com and towardsdatascience.com (if all sources need to be listed, I will edit the post).

I originally wrote a naive custom Matrix class with O(N^3) matrix multiplication, but removed it to use ujmp instead. I used the ujmp.org Matrix class for faster matrix multiplication, but due to my limited understanding of how to make use of its speed-ups, I believe that

This is the final code. Please comment and suggest improvements! Thank you. I will include SGD, backpropagation, feed forward and the mini-batch calculation. Backprop is private because it is only called through a wrapping method called train.

The class NetworkInput is a wrapper for an attributes DenseMatrix and a label DenseMatrix.

All the functions used here (activation functions, functions that evaluate the test data, and functions that calculate loss and error) are interfaces.
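
To make the role of these interfaces concrete, here is a minimal sketch of what an activation-function interface could look like. The method names applyFunction and applyDerivative mirror the calls in the code below, but the interface itself, the Sigmoid class, and the use of plain double[] arrays in place of ujmp's DenseMatrix are assumptions for illustration, not the author's actual types.

```java
import java.util.Arrays;

// Hypothetical activation-function interface: method names mirror the question's calls,
// but double[] stands in for ujmp's DenseMatrix.
interface ActivationFunction {
    double[] applyFunction(double[] z);   // element-wise f(z)
    double[] applyDerivative(double[] z); // element-wise f'(z)
}

// Logistic sigmoid f(z) = 1 / (1 + e^-z), whose derivative is f(z) * (1 - f(z)).
class Sigmoid implements ActivationFunction {
    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    @Override
    public double[] applyFunction(double[] z) {
        return Arrays.stream(z).map(Sigmoid::sigmoid).toArray();
    }

    @Override
    public double[] applyDerivative(double[] z) {
        return Arrays.stream(z).map(x -> sigmoid(x) * (1.0 - sigmoid(x))).toArray();
    }
}
```

Keeping the function and its derivative together in one implementation is what lets backpropagation call functions[k + 1].applyDerivative(z) without knowing which activation a layer uses.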

This is the SGD.

/**
 * Provides an implementation of SGD for this neural network.
 *
 * @param training  a List of {@link NetworkInput} objects;
 *                  NetworkInput.getData() is the data, NetworkInput.getLabel() is the label.
 * @param test      a List of {@link NetworkInput} objects;
 *                  NetworkInput.getData() is the data, NetworkInput.getLabel() is the label.
 * @param epochs    the number of epochs (full passes over the training data) to run SGD for
 * @param batchSize the number of examples per mini-batch, typically 32. See https://stats.stackexchange.com/q/326663
 */
public void stochasticGradientDescent(@NotNull List<NetworkInput> training,
    @NotNull List<NetworkInput> test,
    int epochs,
    int batchSize) {

    int trDataSize = training.size();
    int teDataSize = test.size();

    for (int i = 0; i < epochs; i++) {
        // Randomize training sample.
        Collections.shuffle(training);

        System.out.println("Calculating epoch: " + (i + 1) + ".");

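        // Do backpropagation on each mini-batch. The old bound (j < trDataSize - batchSize)
        // skipped the final batch; the last batch may now be smaller than batchSize.
        for (int j = 0; j < trDataSize; j += batchSize) {
            calculateMiniBatch(training.subList(j, Math.min(j + batchSize, trDataSize)));
        }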

        // Feed forward the test data
        List<NetworkInput> feedForwardData = this.feedForwardData(test);

        // Evaluate prediction with the interface EvaluationFunction.
        int correct = this.evaluationFunction.evaluatePrediction(feedForwardData).intValue();
        // Calculate loss with the interface ErrorFunction
        double loss = errorFunction.calculateCostFunction(feedForwardData);

        // Add the plotting data, x, y_1, y_2 to the global
        // lists of xValues, correctValues, lossValues.
        addPlotData(i, correct, loss);

        System.out.println("Loss: " + loss);
        System.out.println("Epoch " + (i + 1) + ": " + correct + "/" + teDataSize);

        // Lower the learning rate each iteration? Might implement it, but don't know how.
        // ADAM? Does that go here? Are they different algorithms altogether?
        // TODO: Implement Adam, RMSProp, Momentum?
        // this.learningRate = i % 10 == 0 ? this.learningRate / 4 : this.learningRate;
    }

}
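
Regarding the TODO about lowering the learning rate: one common option is step decay, which multiplies an initial rate by a fixed factor every few epochs. The commented-out line above divides the current rate by 4 whenever i % 10 == 0, which (because 0 % 10 == 0) already quarters it right after the first epoch. The sketch below instead derives the rate from the epoch index; the class LearningRateSchedule and the method stepDecay are hypothetical names, not part of the question's code.

```java
// Hypothetical helper: a step-decay learning-rate schedule.
final class LearningRateSchedule {
    private LearningRateSchedule() {}

    /**
     * Returns initialRate * factor^(epoch / dropEvery), using integer division,
     * so the rate stays constant for dropEvery epochs and then drops by factor.
     */
    static double stepDecay(double initialRate, int epoch, int dropEvery, double factor) {
        if (dropEvery <= 0) {
            throw new IllegalArgumentException("dropEvery must be positive");
        }
        return initialRate * Math.pow(factor, epoch / dropEvery);
    }
}
```

Inside the epoch loop this could be used as this.learningRate = LearningRateSchedule.stepDecay(initialRate, i, 10, 0.25), where initialRate is an assumed extra field holding the starting rate. Adaptive methods such as momentum, RMSProp or Adam are different update rules and would replace the plain subtraction in calculateMiniBatch rather than this schedule.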

Here we calculate each mini-batch and update the weights and biases using the average gradient over the batch.

private void calculateMiniBatch(List<NetworkInput> subList) {
    int size = subList.size();

    double scaleFactor = this.learningRate / size;

    DenseMatrix[] dB = new DenseMatrix[this.totalLayers - 1];
    DenseMatrix[] dW = new DenseMatrix[this.totalLayers - 1];
    for (int i = 0; i < this.totalLayers - 1; i++) {
        DenseMatrix bias = getBias(i);
        DenseMatrix weight = getWeight(i);
        dB[i] = Matrix.Factory.zeros(bias.getRowCount(), bias.getColumnCount());
        dW[i] = Matrix.Factory
            .zeros(weight.getRowCount(), weight.getColumnCount());
    }

    for (NetworkInput data : subList) {
        DenseMatrix dataIn = data.getData();
        DenseMatrix label = data.getLabel();
        List<DenseMatrix[]> deltas = backPropagate(dataIn, label);
        DenseMatrix[] deltaB = deltas.get(0);
        DenseMatrix[] deltaW = deltas.get(1);

        for (int j = 0; j < this.totalLayers - 1; j++) {
            dB[j] = (DenseMatrix) dB[j].plus(deltaB[j]);
            dW[j] = (DenseMatrix) dW[j].plus(deltaW[j]);
        }
    }

    for (int i = 0; i < this.totalLayers - 1; i++) {
        DenseMatrix cW = getWeight(i);
        DenseMatrix cB = getBias(i);

        DenseMatrix scaledDeltaB = (DenseMatrix) dB[i].times(scaleFactor);
        DenseMatrix scaledDeltaW = (DenseMatrix) dW[i].times(scaleFactor);

        DenseMatrix nW = (DenseMatrix) cW.minus(scaledDeltaW);
        DenseMatrix nB = (DenseMatrix) cB.minus(scaledDeltaB);

        setWeight(i, nW);
        setLayerBias(i, nB);
    }
}
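
Written as a formula, the update performed by calculateMiniBatch is the following, where η is learningRate, m is subList.size(), C_x is the cost for a single example x, and k runs over the weight layers (the notation is an editorial assumption, not taken from the original post):

```latex
W^{k} \leftarrow W^{k} - \frac{\eta}{m} \sum_{x \in \text{batch}} \frac{\partial C_x}{\partial W^{k}},
\qquad
b^{k} \leftarrow b^{k} - \frac{\eta}{m} \sum_{x \in \text{batch}} \frac{\partial C_x}{\partial b^{k}},
\qquad k = 0, \dots, \text{totalLayers} - 2
```

scaleFactor in the code is exactly η/m, so multiplying the summed deltas dW[i] and dB[i] by it turns the sum into an average before subtracting.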

This is the backpropagation algorithm.

private List<DenseMatrix[]> backPropagate(DenseMatrix toPredict, DenseMatrix correct) {
    List<DenseMatrix[]> totalDeltas = new ArrayList<>();

    DenseMatrix[] weights = getWeights();
    DenseMatrix[] biases = getBiasesAsMatrices();

    DenseMatrix[] deltaBiases = this.initializeDeltas(biases);
    DenseMatrix[] deltaWeights = this.initializeDeltas(weights);

    // Perform Feed Forward here...
    List<DenseMatrix> activations = new ArrayList<>();
    List<DenseMatrix> xVector = new ArrayList<>();

    // Alters all arrays and lists.
    this.backPropFeedForward(toPredict, activations, xVector, weights, biases);
    // End feedforward

    // Calculate error signal for last layer
    DenseMatrix deltaError;

    // Apply the gradient of the error function to the output of the last layer.
    deltaError = errorFunction
        .applyErrorFunctionGradient(activations.get(activations.size() - 1), correct);

    // Set the deltas to the error signals of bias and weight.
    deltaBiases[deltaBiases.length - 1] = deltaError;
    deltaWeights[deltaWeights.length - 1] = (DenseMatrix) deltaError
        .mtimes(activations.get(activations.size() - 2).transpose());

    // Now iteratively apply the rule
    for (int k = deltaBiases.length - 2; k >= 0; k--) {
        DenseMatrix z = xVector.get(k);
        DenseMatrix differentiate = functions[k + 1].applyDerivative(z);

        deltaError = (DenseMatrix) weights[k + 1].transpose().mtimes(deltaError)
            .times(differentiate);

        deltaBiases[k] = deltaError;
        deltaWeights[k] = (DenseMatrix) deltaError.mtimes(activations.get(k).transpose());
    }

    totalDeltas.add(deltaBiases);
    totalDeltas.add(deltaWeights);

    return totalDeltas;
}
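
For reference, the recurrence that backPropagate implements can be written with the indexing of backPropFeedForward: z^k is xVector.get(k), a^k is activations.get(k), f_k is functions[k], L = totalLayers - 1, and y is correct. This assumes applyErrorFunctionGradient already returns the gradient with respect to the last pre-activation z^{L-1} (as is the case, for example, for softmax combined with cross-entropy); if it returns the gradient with respect to a^L instead, the output delta also needs the factor f_L'(z^{L-1}).

```latex
\begin{aligned}
z^{k} &= W^{k} a^{k} + b^{k}, \qquad a^{k+1} = f_{k+1}\!\left(z^{k}\right), \qquad a^{0} = \text{input} \\
\delta^{L-1} &= \text{applyErrorFunctionGradient}\!\left(a^{L}, y\right) \\
\delta^{k} &= \left( (W^{k+1})^{\mathsf{T}} \delta^{k+1} \right) \odot f_{k+1}'\!\left(z^{k}\right), \qquad k = L-2, \dots, 0 \\
\frac{\partial C}{\partial b^{k}} &= \delta^{k}, \qquad
\frac{\partial C}{\partial W^{k}} = \delta^{k} \left(a^{k}\right)^{\mathsf{T}}
\end{aligned}
```

In ujmp terms, mtimes is the matrix product and times(Matrix) is the element-wise product ⊙, which is why the code chains weights[k + 1].transpose().mtimes(deltaError).times(differentiate).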

EDIT: I forgot to include the feed-forward algorithm.

private void backPropFeedForward(DenseMatrix starter, List<DenseMatrix> actives,
    List<DenseMatrix> vectors,
    DenseMatrix[] weights, DenseMatrix[] biases) {
    DenseMatrix toPredict = starter;
    // The input itself is the first activation: backPropagate uses activations.get(0)
    // for the first layer's weight gradient, so it has to be the input rather than zeros.
    actives.add(starter);
    for (int i = 0; i < getTotalLayers() - 1; i++) {
        DenseMatrix x = (DenseMatrix) weights[i].mtimes(toPredict).plus(biases[i]);
        vectors.add(x);

        toPredict = this.functions[i + 1].applyFunction(x);
        actives.add(toPredict);
    }
}
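
A self-contained sketch of the same feed-forward pass on plain arrays may help when checking the indexing. It keeps the input as activations.get(0) and, like the question's code, applies functions[i + 1] after layer i (functions[0] belongs to the input layer and is unused). The class FeedForwardSketch, the double[] representation, and the use of DoubleUnaryOperator as the activation type are assumptions for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

// Hypothetical array-based version of backPropFeedForward.
final class FeedForwardSketch {
    private FeedForwardSketch() {}

    // Computes W * a + b for a row-major weight matrix W.
    static double[] affine(double[][] w, double[] a, double[] b) {
        double[] z = new double[w.length];
        for (int r = 0; r < w.length; r++) {
            double sum = b[r];
            for (int c = 0; c < a.length; c++) {
                sum += w[r][c] * a[c];
            }
            z[r] = sum;
        }
        return z;
    }

    // Returns activations a^0..a^L (a^0 is the input) and appends z^0..z^{L-1} to zs.
    static List<double[]> feedForward(double[] input, double[][][] weights, double[][] biases,
                                      DoubleUnaryOperator[] functions, List<double[]> zs) {
        List<double[]> activations = new ArrayList<>();
        activations.add(input);
        double[] a = input;
        for (int i = 0; i < weights.length; i++) {
            double[] z = affine(weights[i], a, biases[i]);
            zs.add(z);
            double[] next = new double[z.length];
            for (int r = 0; r < z.length; r++) {
                next[r] = functions[i + 1].applyAsDouble(z[r]);
            }
            activations.add(next);
            a = next;
        }
        return activations;
    }
}
```

With L weight layers this yields L + 1 activations and L pre-activations, matching the activations.size() - 1 and activations.size() - 2 lookups in backPropagate.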

1 Answer


Ooh, hard to make any comments about this without knowing neural networks.

However, I do see a lot of repetition, especially when it comes to bias and weight. It may be a good idea to create generic methods for those. I see you use specific methods such as getBias(i) and getWeight(i), but those can be passed in as lambdas or method references (e.g. this::getBias).
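
A minimal sketch of that idea, assuming plain double[][] arrays instead of ujmp matrices so it stays self-contained: one helper allocates zero buffers shaped like whatever a getter returns, so the duplicated weight and bias loops collapse into two calls. The names Gradients and zerosLike are hypothetical.

```java
import java.util.function.IntFunction;

// Hypothetical helper; double[][] stands in for DenseMatrix.
final class Gradients {
    private Gradients() {}

    // Allocates one zero matrix per layer, shaped like getter.apply(i).
    static double[][][] zerosLike(int layerCount, IntFunction<double[][]> getter) {
        double[][][] zeros = new double[layerCount][][];
        for (int i = 0; i < layerCount; i++) {
            double[][] template = getter.apply(i);
            int cols = template.length == 0 ? 0 : template[0].length;
            zeros[i] = new double[template.length][cols];
        }
        return zeros;
    }
}
```

In calculateMiniBatch this would become roughly dW = zerosLike(totalLayers - 1, this::getWeight) and dB = zerosLike(totalLayers - 1, this::getBias), with the ujmp types substituted back in.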

The size and scaleFactor variables in calculateMiniBatch are not used in the first two for loops, so I don't understand why they are declared and initialized so early. If you declare variables only where you need them, it becomes easier to extract methods out of large swaths of code, and your code becomes easier to read (because as a reader you don't have to keep track of so many variables).

There are a lot of unexplained offsets such as totalLayers - 1 and activations.size() - 2. They may be clear to you, but in general you should add a comment explaining what you're trying to achieve with them.

In general the functions are too large. Try to keep each one short: if you have three for loops in a row in one function, see if you can extract (private) methods for them instead.

stochasticGradientDescent prints its results instead of returning them. That's not nice; at the very least, document that it produces output. Instead of using System.out, accept a PrintStream out as an argument. Then you can always stream the output to a file or to a String (for testing purposes), and for console output you just pass System.out as the parameter.
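
A small sketch of that suggestion: the per-epoch report takes the PrintStream as a parameter, and a second method shows capturing it as a String, e.g. in a unit test. The class and method names are hypothetical; the message format is copied from the question's println calls.

```java
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;

// Hypothetical reporting helper; the caller chooses where the output goes.
final class EpochReporter {
    private EpochReporter() {}

    // Writes the per-epoch summary to the given stream instead of hard-coding System.out.
    static void reportEpoch(PrintStream out, int epoch, double loss, int correct, int total) {
        out.println("Loss: " + loss);
        out.println("Epoch " + epoch + ": " + correct + "/" + total);
    }

    // Captures the same report as a String, which is handy for tests.
    static String captureReport(int epoch, double loss, int correct, int total) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        PrintStream out = new PrintStream(buffer, true, StandardCharsets.UTF_8);
        reportEpoch(out, epoch, loss, correct, total);
        out.flush();
        return buffer.toString(StandardCharsets.UTF_8);
    }
}
```

For console output you would call EpochReporter.reportEpoch(System.out, i + 1, loss, correct, teDataSize).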

Similarly, calculateMiniBatch doesn't return a value; it calls two setters instead. That's generally not done, as you can assign such things directly to fields. Calling public methods from private methods can be dangerous if they get overridden. For this kind of purpose I might also consider returning an instance of a private WeightAndBias class with just two fields (a record in Java 16+).
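
A sketch of the record-based variant, assuming plain arrays (double[][] for a weight matrix, double[] for a bias vector) rather than ujmp matrices: the update is computed and returned, and the caller decides where to store it. LayerUpdate and update are hypothetical names; records need Java 16 or later.

```java
// Hypothetical result type holding one layer's new parameters.
record WeightAndBias(double[][] weight, double[] bias) {}

final class LayerUpdate {
    private LayerUpdate() {}

    // Returns W - scale * dW and b - scale * dB without touching any fields.
    static WeightAndBias update(double[][] w, double[] b, double[][] dW, double[] dB, double scale) {
        double[][] newW = new double[w.length][];
        for (int r = 0; r < w.length; r++) {
            newW[r] = new double[w[r].length];
            for (int c = 0; c < w[r].length; c++) {
                newW[r][c] = w[r][c] - scale * dW[r][c];
            }
        }
        double[] newB = new double[b.length];
        for (int r = 0; r < b.length; r++) {
            newB[r] = b[r] - scale * dB[r];
        }
        return new WeightAndBias(newW, newB);
    }
}
```

Because update is a pure function, it can be unit-tested without constructing a whole network.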

I'm really wondering why DenseMatrix is not parameterized properly: I keep seeing casts back to DenseMatrix even though the methods are clearly defined on DenseMatrix itself. That probably means the interface declaring those methods has no self-referencing generic type parameter that DenseMatrix could set to itself, e.g.

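interface Matrix<T extends Matrix<T>> {
    T operation();
}

class DenseMatrix implements Matrix<DenseMatrix> {
    @Override
    public DenseMatrix operation() {
        return this; // placeholder: a real implementation would return a new matrix
    }
}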

Otherwise, I'll be glad to let you know that I don't understand the first thing about the code, so I'll stop while I'm behind :)

