Skip to content

Commit

Permalink
Increased coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
mghosh00 committed Apr 6, 2024
1 parent 71cd662 commit 8a6730d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
13 changes: 6 additions & 7 deletions javaNeuralNetwork/src/main/java/neural_network/util/Plotter.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import java.awt.*;
import java.io.IOException;
import java.nio.file.FileAlreadyExistsException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
Expand Down Expand Up @@ -46,9 +45,9 @@ public static void datapointScatter(NavigableMap<Header, List<String>> df, Strin
// Create chart instance if it has not been set already
if (chart == null) {
chart = new XYChartBuilder().width(800).height(600)
.title("%s %s for %s data".formatted(actualOrPredicted, valuesOrClasses, phase))
.xAxisTitle("x1").yAxisTitle("x2").build();
}
chart.setTitle("%s %s for %s data".formatted(actualOrPredicted, valuesOrClasses, phase));
chart.getStyler().setMarkerSize(5);
if (! regression) {
NavigableMap<String, List<List<Double>>> dataByCategory = organiseDataByCategory(
Expand Down Expand Up @@ -124,9 +123,9 @@ public static void comparisonScatter(NavigableMap<Header, List<Double>> df,
// Create chart instance if it is null
if (chart == null) {
chart = new XYChartBuilder().width(800).height(600)
.title("Comparison scatter plot for %s data".formatted(phase))
.xAxisTitle("Actual").yAxisTitle("Predicted").build();
}
chart.setTitle("Comparison scatter plot for %s data".formatted(phase));
chart.getStyler().setMarkerSize(5);

// Add the line y = x as we want to get as close to this line as possible
Expand Down Expand Up @@ -155,7 +154,7 @@ public static void comparisonScatter(NavigableMap<Header, List<Double>> df,
BitmapEncoder.saveBitmap(chart, dirName + phase + "/comparison" + subString,
BitmapEncoder.BitmapFormat.PNG);
} catch (IOException e) {
throw new RuntimeException("Invalid title:" + title);
throw new RuntimeException("Invalid title: " + title);
}

// Reset the chart and wrappedChart
Expand Down Expand Up @@ -200,7 +199,7 @@ public static void plotLoss(Map<String, List<Double>> lossDf, String title) thro
BitmapEncoder.saveBitmap(chart, dirName + "/losses" + subString,
BitmapEncoder.BitmapFormat.PNG);
} catch (IOException e) {
throw new RuntimeException("Invalid title:" + title);
throw new RuntimeException("Invalid title: " + title);
}

chart = null;
Expand Down Expand Up @@ -247,9 +246,9 @@ static XYChart getChart() {
return chart;
}

/** Setter for wrapperChart. Used for mocking.
/** Setter for wrappedChart. Used for mocking.
*
* @return The wrapperChart instance from {@code org.knowm.xchart}.
* @return The wrappedChart instance from {@code org.knowm.xchart}.
*/
static SwingWrapper<XYChart> getWrappedChart() {
return wrappedChart;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ void datapointScatter() throws IOException {
Plotter.setChart(mockChart);
Plotter.setWrappedChart(mockWrappedChart);
Plotter.datapointScatter(scatterDf, "training", "test_title", false);
verify(mockChart, times(1))
.setTitle("Predicted classes for training data");
verify(mockChart, times(1))
.getStyler();
verify(mockStyler, times(1))
Expand Down Expand Up @@ -117,6 +119,16 @@ void organiseDataByCategory() {
}
}

@Test
void comparisonScatterErroneous() throws IOException {
Exception exception = assertThrows(RuntimeException.class,
() -> Plotter.comparisonScatter(regScatterDf, "good_phase",
"bad_title/"));
assertEquals("Invalid title: bad_title/",
exception.getMessage());
Files.delete(Path.of(dirName + "good_phase"));
}

@Test
void comparisonScatter() throws IOException {
Plotter.setShowPlots(true);
Expand All @@ -134,6 +146,8 @@ void comparisonScatter() throws IOException {
Plotter.setChart(mockChart);
Plotter.setWrappedChart(mockWrappedChart);
Plotter.comparisonScatter(regScatterDf, "training", "test_title");
verify(mockChart, times(1))
.setTitle("Comparison scatter plot for training data");
verify(mockChart, times(1))
.getStyler();
verify(mockStyler, times(1))
Expand Down Expand Up @@ -166,6 +180,15 @@ void comparisonScatterNoMocks() throws IOException {
assertNull(Plotter.getWrappedChart());
}


@Test
void plotLossErroneous() throws IOException {
Exception exception = assertThrows(RuntimeException.class,
() -> Plotter.plotLoss(lossDf,"bad_title/"));
assertEquals("Invalid title: bad_title/",
exception.getMessage());
}

@Test
void plotLoss() throws IOException {
Plotter.setShowPlots(true);
Expand Down

0 comments on commit 8a6730d

Please sign in to comment.