Skip to content

Commit

Permalink
added confidence level feature and fixed mobile version
Browse files Browse the repository at this point in the history
  • Loading branch information
RusFortunat committed Nov 15, 2024
1 parent 97949d5 commit bd81c79
Show file tree
Hide file tree
Showing 14 changed files with 417 additions and 227 deletions.
320 changes: 257 additions & 63 deletions spring.log

Large diffs are not rendered by default.

Binary file added spring.log.2024-11-13.0.gz
Binary file not shown.
32 changes: 10 additions & 22 deletions src/main/java/com/guessNumbersWithAI/controller/ViewController.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,13 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.util.ResourceUtils;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.*;

import java.io.File;
import java.io.FileWriter;
import java.io.InputStream;
import java.util.ArrayList;

@Controller
public class ViewController {
Expand All @@ -32,44 +25,39 @@ public class ViewController {
public String hangleGetMapping(Model theModel){

NeuralNetwork ourNeuralNetwork = new NeuralNetwork();
DrawnImages drawnImages = new DrawnImages();

theModel.addAttribute("NeuralNetwork", ourNeuralNetwork);
theModel.addAttribute("DrawnImages", drawnImages);

return "main-view";
}

@PostMapping("/write-number")
@PostMapping(value = "/getAnswerFromServer")
public String hanglePostMapping(@ModelAttribute("NeuralNetwork") NeuralNetwork ourNeuralNetwork,
@ModelAttribute("DrawnImages") DrawnImages drawnImages){
@RequestBody String rawImageData){

try{

// create input vector by processing raw pixel image data; we turn 280x280 px image into 28x28 px image
DrawnImages drawnImages = new DrawnImages(rawImageData);
double[] inputVector = drawnImages.processRawInput();

// load network parameters from the file
/*Resource resource=resourceLoader.getResource(
"classpath:net_params_size784_256_10_lr0.001_trainEps100.txt");*/
Resource resource= resourceLoader.getResource(
"classpath:net_params_size784_256_10_lr0.001_trainEps100.txt");
InputStream inputStream = resource.getInputStream();
ourNeuralNetwork.loadNetworkParameters(inputStream);

// create input vector by processing raw pixel image data; we turn 280x280 px image into 28x28 px image
double[] inputVector = drawnImages.processRawInput();

// pass input vector to network and get the prediction; the answer will be displayed on the html view page
ourNeuralNetwork.forward(inputVector);
System.out.println("Answer: " + ourNeuralNetwork.getAnswer());

// save image to the H2 database for future neural networks training

// save image to the H2 database for future training of our neural networks
drawnImages.saveImageToDB();

}catch(Exception e){
logger.error(e.getMessage());
System.out.println(e.getMessage());

}

return "main-view";
return "fragment/ajaxPart :: returnedAnswer";
}
}
22 changes: 7 additions & 15 deletions src/main/java/com/guessNumbersWithAI/model/DrawnImages.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@

public class DrawnImages {

private ArrayList<Double> rawPixelInput;
private double[] rawPixelInput;
private double[] inputVector;

public DrawnImages(){
this.rawPixelInput = new ArrayList<>();
public DrawnImages(String rawImageData){

String[] rawData = rawImageData.split(",");
rawPixelInput = Arrays.stream(rawData).mapToDouble(Double::parseDouble).toArray();

inputVector = new double[28*28]; // MNIST training dataset consists of 28x28 pixel images.
}


// compress 280x280 pixel image to 28x28 one
public double[] processRawInput(){

Expand All @@ -32,7 +34,7 @@ public double[] processRawInput(){
// new pixel is an average of 100 pixels in 10x10 square
int index = 280 * (10 * Y + y) + 10 * X + x;

avergePixelValue += rawPixelInput.get(index) / 100.0;
avergePixelValue += rawPixelInput[index] / 100.0;
}
}

Expand Down Expand Up @@ -122,14 +124,4 @@ private Connection createConnectionAndEnsureDatabase() throws SQLException {

return conn;
}

// getters and setters

public ArrayList<Double> getRawPixelInput() {
return rawPixelInput;
}

public void setRawPixelInput(ArrayList<Double> rawPixelInput) {
this.rawPixelInput = rawPixelInput;
}
}
73 changes: 48 additions & 25 deletions src/main/java/com/guessNumbersWithAI/model/NeuralNetwork.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
// I will not be training the neural networks here, but will simply upload their parameters from file.
// The first trained neural network model was obtained by me writing the whole supervised learning
// machinery in Java from scratch, and the second was obtained with use of PyTorch machine learning package.
// If you are interested in how I trained these two neural networks, check out this project:
// The trained neural network model was obtained by me writing the whole supervised learning machinery in Java
// from scratch. If you are interested in how I trained these two neural networks, check out this project:
// My implementation from scratch in Java: https://github.com/RusFortunat/java_ML_library
// Python PyTorch and my C++ implementation: https://github.com/RusFortunat/alternative-ML-lib-C2plus

package com.guessNumbersWithAI.model;

Expand All @@ -13,12 +11,10 @@

public class NeuralNetwork {

// I have two trained networks, and this parameter determines which of them the user will use
public String chosenNetworkModel;

private int inputSize;
private int hiddenSize;
private int outputSize;
private double[] outputVector;

// the neural network parameters that will be loaded from the file
private double[][] firstLayerWeights;
Expand All @@ -33,11 +29,11 @@ public class NeuralNetwork {
// I will let the user to choose which neural network to use and load parameters later
public NeuralNetwork() {

this.chosenNetworkModel = "";

this.inputSize = 28*28; // MNIST training images are of 28x28 pixel size
this.hiddenSize = 256; // arbitrary, should not be too small or too big
this.outputSize = 10; // 0-9 digits that network will be guessing
this.outputVector = new double[outputSize];

this.firstLayerWeights = new double[hiddenSize][inputSize];
this.firstLayerBiases = new double[hiddenSize];
this.secondLayerWeights = new double[outputSize][hiddenSize];
Expand All @@ -55,7 +51,6 @@ public void forward(double[] input) throws RuntimeException{
// 3. Repeat 1.-2. for the next layer with secondLayerWeights and secondLayerBiases to get the [output] vector.

double[] hiddenVector = new double[hiddenSize];
double[] outputVector = new double[outputSize];

// compute hidden activation values
for(int i = 0; i < hiddenSize; i++){
Expand All @@ -67,21 +62,18 @@ public void forward(double[] input) throws RuntimeException{
hiddenVector[i] = sum;
}
// compute output activations
double totalSum = 0.0;
double smallestValue = 0.0; // sum of activations can be negative
for(int i = 0; i < outputSize; i++){
double sum = 0;
for(int j = 0; j < hiddenSize; j++){
double activation = secondLayerWeights[i][j]*hiddenVector[j] + secondLayerBiases[i];
sum+= activation; // no relu on output values
}
outputVector[i] = sum;
totalSum += Math.exp(sum);
}

// normalize the output vector using the SoftMax approach;
// i will keep it for future, if i'll need the whole vector
for(int i = 0; i < outputSize; i++){
outputVector[i] = Math.exp(outputVector[i]) / totalSum;
if(sum < smallestValue){
smallestValue = sum;
}
}

// return the index (which also represents the number) of the output vector with the highest value
Expand All @@ -94,6 +86,37 @@ public void forward(double[] input) throws RuntimeException{
}
}

// if one of the values is negative, shift the entire vector
System.out.println("output vector before softmax:");
double totalSum = 0;
for(int i = 0; i < outputSize; i++){
if(smallestValue < 0) outputVector[i] += Math.abs(smallestValue);
totalSum += outputVector[i];
}

// normalize the output vector
for(int i = 0; i < outputSize; i++){
outputVector[i] =outputVector[i] / totalSum;
}

// now let's magnify the difference between the output values,
// so that we can have clear visual effect on html page
maxValue = outputVector[maxId];
double highConfidenceValue = 0.9*maxValue;
double mediumConfidenceValue = 0.5*maxValue;
double lowConfidenceValue = 0.3*maxValue;
for(int i = 0; i < outputSize; i++){
if(outputVector[i] > highConfidenceValue){
outputVector[i] = Math.min(1.0, 4*outputVector[i]); // opacity is capped at 1.0
}
else if(outputVector[i] > mediumConfidenceValue){
outputVector[i] = Math.min(1.0, 3*outputVector[i]); // opacity is capped at 1.0
}
else if(outputVector[i] > lowConfidenceValue){
outputVector[i] = Math.min(1.0, 3*outputVector[i]); // opacity is capped at 1.0
}
}

answer = maxId;
}

Expand Down Expand Up @@ -151,14 +174,6 @@ public void loadNetworkParameters(InputStream networkParamsFile) throws Exceptio

// getters and setters

public String getChosenNetworkModel() {
return chosenNetworkModel;
}

public void setChosenNetworkModel(String chosenNetworkModel) {
this.chosenNetworkModel = chosenNetworkModel;
}

public int getAnswer() {
return answer;
}
Expand All @@ -167,6 +182,14 @@ public void setAnswer(int answer) {
this.answer = answer;
}

public double[] getOutputVector() {
return outputVector;
}

public void setOutputVector(double[] outputVector) {
this.outputVector = outputVector;
}

// printers, for debug purposes
public void printNetworkParameteres(){
System.out.println("firstLayerWeights:");
Expand Down

This file was deleted.

This file was deleted.

Binary file modified src/main/resources/image-database.mv.db
Binary file not shown.
14 changes: 0 additions & 14 deletions src/main/resources/saved_params.txt

This file was deleted.

14 changes: 0 additions & 14 deletions src/main/resources/saved_tensor_params.txt

This file was deleted.

Binary file added src/main/resources/static/css/main-background.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion src/main/resources/static/css/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

body {
background-color: #343434;
background-image: url("main-background.png");
}

h1, h2, h3, h4 {
color: #fff;
font-family: Arial, Helvetica, sans-serif;
}
}
26 changes: 26 additions & 0 deletions src/main/resources/templates/fragment/ajaxPart.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<!DOCTYPE html>
<html lang="en" xmlns:th="http://www.thymeleaf.org">

<div th:fragment="returnedAnswer">

<h2 style=" margin-left:-15px; font-size: 40px"> Answer:
<span style="color: white; font-size: 40px" th:text="${NeuralNetwork.answer}"></span>
</h2>

<h2 style=" margin-left:-15px"> AI's confidence level:</h2>

<span th:id="sp0" style=" margin-left:-15px">0</span>
<span th:id="sp1">1</span>
<span th:id="sp2">2</span>
<span th:id="sp3">3</span>
<span th:id="sp4">4</span>
<span th:id="sp5">5</span>
<span th:id="sp6">6</span>
<span th:id="sp7">7</span>
<span th:id="sp8">8</span>
<span th:id="sp9">9</span>

<input type="hidden" th:id="outputVector" th:field="${NeuralNetwork.outputVector}"/>

</div>
</html>
Loading

0 comments on commit bd81c79

Please sign in to comment.