lee1026td/Pure-Java-Neural-Network


Pure Java Neural Network Project

Environment

  • Java (JDK 21)

About Project

In this project, I implemented a simple Artificial Neural Network (ANN) in pure Java.

Supported Algorithms (Work in Progress)

Activation Functions

  • Identity
  • ReLU (Rectified Linear Unit)
  • ELU (Exponential Linear Unit)
  • LeakyReLU
  • Sigmoid
  • Tanh (Hyperbolic Tangent)
  • Softmax
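For reference, the activations listed above fit in a few lines each. The sketch below is a hypothetical `Activations` helper, not this project's classes (the usage example refers to classes such as `LeakyReLU` and `Sigmoid`, whose internals are not shown here). The LeakyReLU/ELU slope `alpha` is passed in explicitly because the project's default value is not stated. Derivatives are included because backpropagation needs them.

```java
import java.util.Arrays;

/** Hypothetical helper: common activation functions and their derivatives. */
final class Activations {
    private Activations() {}

    static double identity(double x)     { return x; }
    static double identityGrad(double x) { return 1.0; }

    static double relu(double x)     { return Math.max(0.0, x); }
    static double reluGrad(double x) { return x > 0 ? 1.0 : 0.0; }

    // alpha is the slope used for negative inputs (an assumed parameter)
    static double leakyRelu(double x, double alpha)     { return x > 0 ? x : alpha * x; }
    static double leakyReluGrad(double x, double alpha) { return x > 0 ? 1.0 : alpha; }

    // ELU saturates smoothly to -alpha for large negative inputs
    static double elu(double x, double alpha)     { return x > 0 ? x : alpha * (Math.exp(x) - 1.0); }
    static double eluGrad(double x, double alpha) { return x > 0 ? 1.0 : alpha * Math.exp(x); }

    static double sigmoid(double x)     { return 1.0 / (1.0 + Math.exp(-x)); }
    static double sigmoidGrad(double x) { double s = sigmoid(x); return s * (1.0 - s); }

    static double tanh(double x)     { return Math.tanh(x); }
    static double tanhGrad(double x) { double t = Math.tanh(x); return 1.0 - t * t; }

    /** Softmax over one vector; subtracting the max keeps exp() from overflowing. */
    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().orElse(0.0);
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
}
```

Softmax is the only vector-valued function here; the others apply element-wise to a layer's pre-activations.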

Loss Functions

  • MSELoss
  • BinaryCrossEntropyLoss
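The two losses and their gradients with respect to the predictions can be sketched as below. This is a hypothetical `Losses` helper, not the project's `MSELoss` or `BinaryCrossEntropyLoss` classes; the clamping constant `EPS` is an assumed value that keeps `log()` finite when a prediction reaches exactly 0 or 1.

```java
/** Hypothetical helper: mean losses over n predictions and their gradients dL/dyHat. */
final class Losses {
    private Losses() {}

    // Assumed clamp so log(p) and log(1 - p) never see 0
    private static final double EPS = 1e-12;

    static double mse(double[] yHat, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double d = yHat[i] - y[i];
            sum += d * d;
        }
        return sum / y.length;
    }

    /** dL/dyHat for MSE: 2 (yHat - y) / n. */
    static double[] mseGrad(double[] yHat, double[] y) {
        double[] g = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            g[i] = 2.0 * (yHat[i] - y[i]) / y.length;
        }
        return g;
    }

    static double binaryCrossEntropy(double[] yHat, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double p = Math.min(Math.max(yHat[i], EPS), 1.0 - EPS);
            sum += y[i] * Math.log(p) + (1.0 - y[i]) * Math.log(1.0 - p);
        }
        return -sum / y.length;
    }

    /** dL/dyHat for BCE: (p - y) / (p (1 - p) n), using the same clamped p. */
    static double[] bceGrad(double[] yHat, double[] y) {
        double[] g = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            double p = Math.min(Math.max(yHat[i], EPS), 1.0 - EPS);
            g[i] = (p - y[i]) / (p * (1.0 - p) * y.length);
        }
        return g;
    }
}
```

When BCE follows a sigmoid output (as in the usage example), multiplying `bceGrad` by the sigmoid derivative p(1 - p) cancels the denominator, so the gradient with respect to the pre-activation is simply (p - y) / n.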

Optimizers

  • Stochastic Gradient Descent (including SGD with momentum)
  • RMSProp
  • AdaGrad
  • Adam
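The update rules behind these optimizers can be sketched on a flat `double[]` of parameters, with the caller holding the per-parameter state arrays. This `Optimizers` class is hypothetical and is not the project's `AdamOptimizer` API. The usage example's `new AdamOptimizer(0.01, 0.9, 0.99, 1e-8)` presumably corresponds to the learning rate, beta1, beta2 and epsilon used below, but that mapping is an assumption.

```java
/** Hypothetical sketch: in-place update rules for one flat parameter array. */
final class Optimizers {
    private Optimizers() {}

    /** SGD with momentum: v = mu * v - lr * g; w += v. Setting mu = 0 gives plain SGD. */
    static void sgdMomentum(double[] w, double[] g, double[] v, double lr, double mu) {
        for (int i = 0; i < w.length; i++) {
            v[i] = mu * v[i] - lr * g[i];
            w[i] += v[i];
        }
    }

    /** AdaGrad: s accumulates squared gradients, so each parameter's step shrinks over time. */
    static void adaGrad(double[] w, double[] g, double[] s, double lr, double eps) {
        for (int i = 0; i < w.length; i++) {
            s[i] += g[i] * g[i];
            w[i] -= lr * g[i] / (Math.sqrt(s[i]) + eps);
        }
    }

    /** RMSProp: s is an exponential moving average of squared gradients with decay rho. */
    static void rmsProp(double[] w, double[] g, double[] s, double lr, double rho, double eps) {
        for (int i = 0; i < w.length; i++) {
            s[i] = rho * s[i] + (1.0 - rho) * g[i] * g[i];
            w[i] -= lr * g[i] / (Math.sqrt(s[i]) + eps);
        }
    }

    /** Adam: bias-corrected first (m) and second (v) moment estimates; t counts steps from 1. */
    static void adam(double[] w, double[] g, double[] m, double[] v, int t,
                     double lr, double beta1, double beta2, double eps) {
        double c1 = 1.0 - Math.pow(beta1, t);
        double c2 = 1.0 - Math.pow(beta2, t);
        for (int i = 0; i < w.length; i++) {
            m[i] = beta1 * m[i] + (1.0 - beta1) * g[i];
            v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i];
            double mHat = m[i] / c1;
            double vHat = v[i] / c2;
            w[i] -= lr * mHat / (Math.sqrt(vHat) + eps);
        }
    }
}
```

Keeping the state arrays outside the methods mirrors how a network would store one set of moments per weight matrix; all state arrays start at zero.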

Metrics

  • Binary accuracy
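Binary accuracy compares thresholded sigmoid outputs with 0/1 labels. The sketch below is hypothetical (it is not the project's `BinaryAccuracy` class), and the 0.5 threshold is an assumed default passed in explicitly.

```java
/** Hypothetical helper: fraction of thresholded predictions that match 0/1 labels. */
final class BinaryAccuracySketch {
    private BinaryAccuracySketch() {}

    static double binaryAccuracy(double[] probs, double[] labels, double threshold) {
        int correct = 0;
        for (int i = 0; i < probs.length; i++) {
            // Probabilities at or above the threshold count as class 1
            double predicted = probs[i] >= threshold ? 1.0 : 0.0;
            if (predicted == labels[i]) {
                correct++;
            }
        }
        return (double) correct / probs.length;
    }
}
```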

Scaler

  • Min-Max Scaler (MinMaxScalar)
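Min-max scaling maps each feature to [0, 1] via (x - min) / (max - min). The sketch below is a hypothetical stand-in for `MinMaxScalar.minMaxScalar`, assuming a samples x features array and scaling per column. Note that the usage example transposes `X` to features x samples, so which axis the project's scaler reduces over is not stated here.

```java
/** Hypothetical helper: per-column min-max scaling of a samples x features array. */
final class MinMaxSketch {
    private MinMaxSketch() {}

    static double[][] minMaxScale(double[][] data) {
        int rows = data.length;
        int cols = data[0].length;
        double[][] out = new double[rows][cols];
        for (int j = 0; j < cols; j++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : data) {
                min = Math.min(min, row[j]);
                max = Math.max(max, row[j]);
            }
            double range = max - min;
            for (int i = 0; i < rows; i++) {
                // A constant column has zero range; map it to 0 instead of dividing by zero
                out[i][j] = range == 0.0 ? 0.0 : (data[i][j] - min) / range;
            }
        }
        return out;
    }
}
```

For evaluation on new data, the usual practice is to reuse the min and max computed on the training set rather than recomputing them.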

Usage (Example)

public class Main {

    public static void main(String[] args) {

        // CSV with the two input features in columns 0-1 and the 0/1 label in column 2
        CSVReader reader = new CSVReader("G:\\Datasets\\moon_dataset.csv");
        double[][] arr = reader.readCSV();

        Network nn = new Network(
                new Layer(2, 4, new LeakyReLU(), Initializer.InitType.HE),
                new Layer(4, 6, new LeakyReLU(), Initializer.InitType.HE),
                new Layer(6, 6, new LeakyReLU(), Initializer.InitType.HE),
                new Layer(6, 4, new LeakyReLU(), Initializer.InitType.HE),
                new Layer(4, 1, new Sigmoid(), Initializer.InitType.XAVIER)
        );

        // Adam(learning rate, beta1, beta2, epsilon), binary cross-entropy loss, binary accuracy metric
        nn.compile(new AdamOptimizer(0.01, 0.9, 0.99, 1e-8), new BinaryCrossEntropyLoss(), new BinaryAccuracy());

        Matrix datasets = new Matrix(arr);

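        /* Split the columns: 0-1 are the input features, 2 is the label.
           X is transposed so that each column holds one sample (2 x n). */
        Matrix X = datasets.getColumnTo(0, 1).transpose();
        Matrix Y = datasets.getColumnTo(2, 2);

        /* Optional: min-max normalize the inputs. If enabled, pass X_norm
           to train() and predict() instead of X. */
        //Matrix X_norm = MinMaxScalar.minMaxScalar(X);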

        nn.train(X, Y, 10000, 100);  // train set, labels, epochs, log steps
        Matrix res = nn.predict(X); // Prediction

        System.out.println("Prediction Accuracy : " + nn.getAccuracy(res, Y));
    }
}
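The example initializes the LeakyReLU layers with `Initializer.InitType.HE` and the sigmoid output layer with `Initializer.InitType.XAVIER`. The sketch below shows one common variant of each (He normal with variance 2 / fanIn, Xavier uniform with limit sqrt(6 / (fanIn + fanOut))); the project may use different variants. `InitSketch` is hypothetical, and reading `new Layer(2, 4, ...)` as (fanIn, fanOut) is an assumption.

```java
import java.util.Random;

/** Hypothetical helper: weight initializers returning a fanOut x fanIn matrix. */
final class InitSketch {
    private InitSketch() {}

    /** He (Kaiming) normal: N(0, 2 / fanIn), commonly paired with ReLU-family activations. */
    static double[][] he(int fanIn, int fanOut, Random rng) {
        double std = Math.sqrt(2.0 / fanIn);
        double[][] w = new double[fanOut][fanIn];
        for (double[] row : w) {
            for (int j = 0; j < fanIn; j++) {
                row[j] = rng.nextGaussian() * std;
            }
        }
        return w;
    }

    /** Xavier (Glorot) uniform: U(-a, a) with a = sqrt(6 / (fanIn + fanOut)), common for sigmoid/tanh. */
    static double[][] xavier(int fanIn, int fanOut, Random rng) {
        double a = Math.sqrt(6.0 / (fanIn + fanOut));
        double[][] w = new double[fanOut][fanIn];
        for (double[] row : w) {
            for (int j = 0; j < fanIn; j++) {
                row[j] = (2.0 * rng.nextDouble() - 1.0) * a;
            }
        }
        return w;
    }
}
```

Both schemes scale the initial weights by layer width so that activations keep a roughly constant variance from layer to layer, which matters more as the network in the example gets deeper.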

Tested Datasets

  • make_moons

Example Results (Using Python pyplot)

[Figure: example results on the "make_moons" dataset, plotted with Python pyplot]

TODO

  • Add code comments
  • Add more optimizers
  • Apply dropout
  • Add regularization
  • Add gradient clipping
  • Support batch training

About

Deep Neural Network Design using Pure Java
