forked from tensorflow/tfjs-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathui.js
129 lines (111 loc) · 4.13 KB
/
ui.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tfvis from '@tensorflow/tfjs-vis';
// Page elements written to by logStatus, trainingLog and showTestResults.
// They are looked up once, when this module is first evaluated.
const [statusElement, messageElement, imagesElement] =
    ['status', 'message', 'images'].map((id) => document.getElementById(id));
/**
 * Replaces the text shown in the status element with `message`.
 * @param {string} message - Status text to display.
 */
export function logStatus(message) {
  const target = statusElement;
  target.innerText = message;
}
/**
 * Shows `message`, followed by a newline, in the message element and echoes
 * it to the console. Any text previously in the element is replaced, not
 * appended to.
 * @param {string} message - Training progress text.
 */
export function trainingLog(message) {
  const line = `${message}\n`;
  messageElement.innerText = line;
  console.log(message);
}
/**
 * Replaces the contents of the images element with one entry per example in
 * `batch.xs`. Each entry is a canvas showing the example plus a caption
 * `pred: <prediction>`. The caption's CSS class is 'pred-correct' when the
 * prediction strictly equals the label and 'pred-incorrect' otherwise.
 *
 * @param {{xs: tf.Tensor}} batch - Batch whose `xs` is a 2-D tensor with one
 *     row of pixel values per example.
 * @param {ArrayLike<*>} predictions - Predicted class for each row of `xs`.
 * @param {ArrayLike<*>} labels - True class for each row of `xs`.
 */
export function showTestResults(batch, predictions, labels) {
  const [testExamples, pixelsPerExample] = batch.xs.shape;
  imagesElement.innerHTML = '';
  for (let i = 0; i < testExamples; i++) {
    const div = document.createElement('div');
    div.className = 'pred-container';

    const canvas = document.createElement('canvas');
    canvas.className = 'prediction-canvas';

    // slice() and flatten() each allocate a new tensor. The JavaScript garbage
    // collector does not reclaim TF.js tensor memory (notably on the WebGL
    // backend), so release both once draw() has read the pixel data.
    const image = batch.xs.slice([i, 0], [1, pixelsPerExample]);
    const flat = image.flatten();
    try {
      draw(flat, canvas);
    } finally {
      flat.dispose();
      image.dispose();
    }

    const prediction = predictions[i];
    const label = labels[i];
    const correct = prediction === label;
    const pred = document.createElement('div');
    pred.className = `pred ${correct ? 'pred-correct' : 'pred-incorrect'}`;
    pred.innerText = `pred: ${prediction}`;

    div.appendChild(pred);
    div.appendChild(canvas);
    imagesElement.appendChild(div);
  }
}
// Text labels that plotLoss and plotAccuracy update with the most recent
// value. They are looked up once, when this module is first evaluated.
const [lossLabelElement, accuracyLabelElement] =
    ['loss-label', 'accuracy-label'].map((id) => document.getElementById(id));
// Accumulated loss points, one array of {x, y} per series. Index 0 holds the
// 'train' series and index 1 the 'validation' series, matching the order of
// the series names passed to the chart.
const lossValues = [[], []];

/**
 * Records a loss value and re-renders the loss line chart from every point
 * collected so far. It also updates the loss label.
 * @param {number} batch - X coordinate (batch number) of the new point.
 * @param {number} loss - Loss value; shown in the label with 3 decimals.
 * @param {string} set - 'train' adds the point to the train series; any other
 *     value adds it to the validation series.
 */
export function plotLoss(batch, loss, set) {
  let target = lossValues[1];
  if (set === 'train') {
    target = lossValues[0];
  }
  target.push({x: batch, y: loss});

  const chartData = {values: lossValues, series: ['train', 'validation']};
  const chartOptions = {
    xLabel: 'Batch #',
    yLabel: 'Loss',
    width: 400,
    height: 300,
  };
  tfvis.render.linechart(
      document.getElementById('loss-canvas'), chartData, chartOptions);

  lossLabelElement.innerText = `last loss: ${loss.toFixed(3)}`;
}
// Accumulated accuracy points, one array of {x, y} per series. Index 0 holds
// the 'train' series and index 1 the 'validation' series.
const accuracyValues = [[], []];

/**
 * Records an accuracy value and re-renders the accuracy line chart from every
 * point collected so far. It also updates the accuracy label.
 * @param {number} batch - X coordinate (batch number) of the new point.
 * @param {number} accuracy - Accuracy value; shown in the label multiplied by
 *     100 with one decimal and a '%' suffix.
 * @param {string} set - 'train' adds the point to the train series; any other
 *     value adds it to the validation series.
 */
export function plotAccuracy(batch, accuracy, set) {
  const container = document.getElementById('accuracy-canvas');
  const point = {x: batch, y: accuracy};
  if (set === 'train') {
    accuracyValues[0].push(point);
  } else {
    accuracyValues[1].push(point);
  }

  const data = {values: accuracyValues, series: ['train', 'validation']};
  const options = {
    xLabel: 'Batch #',
    yLabel: 'Accuracy',
    width: 400,
    height: 300,
  };
  tfvis.render.linechart(container, data, options);

  const percent = (accuracy * 100).toFixed(1);
  accuracyLabelElement.innerText = `last accuracy: ${percent}%`;
}
/**
 * Paints a grayscale image onto `canvas` and resizes the canvas to
 * `width` x `height` pixels.
 *
 * Each of the first width*height values from `image.dataSync()` is taken in
 * row-major order and multiplied by 255. The result is written to the R, G and
 * B channels, with alpha fully opaque. ImageData stores pixels in a
 * Uint8ClampedArray, so scaled values are clamped to [0, 255] and rounded.
 * Presumably the source values lie in [0, 1]; confirm against the data loader.
 *
 * @param {tf.Tensor} image - Tensor holding width*height pixel values; its
 *     shape is not inspected, only its flat data.
 * @param {HTMLCanvasElement} canvas - Canvas to draw into.
 * @param {number} [width=28] - Image width in pixels.
 * @param {number} [height=28] - Image height in pixels.
 */
export function draw(image, canvas, width = 28, height = 28) {
  canvas.width = width;
  canvas.height = height;
  const ctx = canvas.getContext('2d');
  const imageData = new ImageData(width, height);
  const pixels = imageData.data;
  const data = image.dataSync();
  for (let i = 0; i < width * height; ++i) {
    const value = data[i] * 255;
    const j = i * 4;  // 4 bytes (RGBA) per pixel.
    pixels[j] = value;
    pixels[j + 1] = value;
    pixels[j + 2] = value;
    pixels[j + 3] = 255;
  }
  ctx.putImageData(imageData, 0, 0);
}
/**
 * Returns the `value` of the element with id 'model-type'.
 * @returns {string}
 */
export function getModelTypeId() {
  const modelTypeInput = document.getElementById('model-type');
  return modelTypeInput.value;
}
/**
 * Reads the number of training epochs from the element with id
 * 'train-epochs'.
 *
 * The value is parsed as a base-10 integer: leading whitespace is skipped and
 * anything after the leading digits is ignored. Returns NaN when no integer
 * can be parsed (e.g. an empty field).
 * @returns {number}
 */
export function getTrainEpochs() {
  // Explicit radix so the field is always read as decimal (without it,
  // strings like '0x10' would be parsed as hexadecimal).
  return Number.parseInt(document.getElementById('train-epochs').value, 10);
}
/**
 * Registers `callback` to run when the 'train' button is clicked.
 *
 * On click, the 'train' button and then the 'model-type' control each get a
 * `disabled` attribute, after which `callback` is invoked with no arguments
 * (its return value is ignored). Nothing here re-enables them.
 * @param {function(): *} callback - Action to start when training is requested.
 */
export function setTrainButtonCallback(callback) {
  const trainButton = document.getElementById('train');
  const modelType = document.getElementById('model-type');
  const onClick = () => {
    for (const control of [trainButton, modelType]) {
      control.setAttribute('disabled', true);
    }
    callback();
  };
  trainButton.addEventListener('click', onClick);
}