-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from DjDeveloperr/native-backend
feat: Native backend in C
- Loading branch information
Showing
56 changed files
with
3,018 additions
and
146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,7 @@ | ||
.vscode/ | ||
.idea/ | ||
.idea/ | ||
build/ | ||
xor_model.bin | ||
node_modules/ | ||
digit_model.bin | ||
bench/tfjs/node_modules |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export { CPU } from "../src/cpu/mod.ts"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export { GPU } from "../src/gpu/mod.ts"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
export { Native } from "../src/native/mod.ts"; | ||
export { Matrix } from "../src/native/matrix.ts"; | ||
export type { DataType } from "../src/native/matrix.ts"; | ||
export type { Dataset } from "../src/native/backend.ts"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import { NeuralNetwork, DenseLayer } from "../../mod.ts"; | ||
import { Native, Matrix } from "../../backends/native.ts"; | ||
|
||
const start = Date.now(); | ||
|
||
const network = await new NeuralNetwork({ | ||
input: 2, | ||
layers: [ | ||
new DenseLayer({ size: 3, activation: "sigmoid" }), | ||
new DenseLayer({ size: 1, activation: "sigmoid" }), | ||
], | ||
cost: "crossentropy", | ||
}).setupBackend(Native); | ||
|
||
network.train( | ||
[ | ||
{ | ||
inputs: Matrix.of([ | ||
[0, 0], | ||
[0, 1], | ||
[1, 0], | ||
[1, 1], | ||
]), | ||
outputs: Matrix.column([0, 1, 1, 0]), | ||
}, | ||
], | ||
5000, | ||
0.1, | ||
); | ||
|
||
console.log("training time", Date.now() - start); | ||
|
||
console.log( | ||
network.predict( | ||
Matrix.of([ | ||
[0, 0], | ||
[0, 1], | ||
[1, 0], | ||
[1, 1], | ||
]), | ||
), | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
import time | ||
|
||
|
||
start = time.time() | ||
|
||
#tf.get_logger().setLevel('INFO') | ||
#tf.autograph.set_verbosity(3) | ||
|
||
x_train, y_train = (tf.constant([[0,0],[0,1],[1,0],[1,1]], "float32"), tf.constant([[0],[1],[1],[0]], "float32")) | ||
XOR_True = [(1, 0), (0, 1)] | ||
XOR_False = [(0, 0), (1, 1)] | ||
|
||
model = tf.keras.models.Sequential([ | ||
tf.keras.layers.Flatten(input_shape=(2,)), | ||
tf.keras.layers.Dense(3, activation=tf.nn.sigmoid), # hidden layer | ||
tf.keras.layers.Dense(1, activation=tf.nn.sigmoid) # output layer | ||
]) | ||
|
||
model.compile( | ||
# optimizer='adam', | ||
loss='binary_crossentropy', | ||
#loss='mean_squared_error', # try this too; treat as regression problem | ||
metrics=['accuracy']) | ||
|
||
|
||
model.fit(x_train, y_train, epochs=5000, verbose=0) | ||
print("Training took", time.time() - start, "ms") | ||
|
||
print("XOR True") | ||
for x in XOR_True: | ||
print(model.predict(np.array([x]))) | ||
print("XOR False") | ||
for x in XOR_False: | ||
print(model.predict(np.array([x]))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"name": "bench", | ||
"version": "1.0.0", | ||
"description": "", | ||
"main": "xor.mjs", | ||
"scripts": { | ||
"start": "node xor.mjs" | ||
}, | ||
"keywords": [], | ||
"author": "", | ||
"license": "ISC", | ||
"dependencies": { | ||
"@tensorflow/tfjs-node": "^3.20.0" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// deno-lint-ignore-file | ||
import * as tf from '@tensorflow/tfjs-node'; | ||
|
||
const model = tf.sequential(); | ||
model.add(tf.layers.dense({inputShape:[2], units: 3, activation: 'sigmoid'})); | ||
model.add(tf.layers.dense({units: 1, activation: 'sigmoid'})); | ||
model.compile({ loss: "meanSquaredError", optimizer: "sgd" }); | ||
|
||
const xs = tf.tensor2d([ | ||
[0, 0], | ||
[0, 1], | ||
[1, 0], | ||
[1, 1], | ||
]); | ||
|
||
const ys = tf.tensor2d([ | ||
[0], | ||
[1], | ||
[1], | ||
[0], | ||
]); | ||
|
||
const start = Date.now(); | ||
model.fit(xs, ys, {epochs: 5000, verbose:0}).then(() => { | ||
console.log("Training took", Date.now() - start, "ms"); | ||
model.predict(xs).print(); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
{ | ||
"tasks": { | ||
"train_xor": "deno run -A --unstable ./examples/train_xor.ts", | ||
"train_xor": "deno run -A --unstable ./examples/train_xor_cpu.ts", | ||
"train_xor_gpu": "deno run -A --unstable ./examples/train_xor_gpu.ts", | ||
"train_letter": "deno run -A --unstable ./examples/train_letter.ts", | ||
"train_emoticon": "deno run -A --unstable ./examples/train_emoticon.ts", | ||
"train_conv": "deno run -A --unstable ./examples/train_conv.ts", | ||
"perf_test": "deno run -A --unstable ./examples/perf_test.ts" | ||
"perf_test": "deno run -A --unstable ./examples/perf_test.ts", | ||
"build": "cd native/build && cmake .. && make" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.idx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import type { Dataset } from "../../backends/native.ts"; | ||
import { Matrix } from "../../backends/native.ts"; | ||
|
||
|
||
export function assert(condition: boolean, message?: string) { | ||
if (!condition) { | ||
throw new Error(message); | ||
} | ||
} | ||
|
||
export function loadDataset(imagesFile: string, labelsFile: string) { | ||
const images = Deno.readFileSync(new URL(imagesFile, import.meta.url)); | ||
const labels = Deno.readFileSync(new URL(labelsFile, import.meta.url)); | ||
|
||
const imageView = new DataView(images.buffer); | ||
const labelView = new DataView(labels.buffer); | ||
|
||
assert(imageView.getUint32(0) === 0x803, "Invalid image file"); | ||
assert(labelView.getUint32(0) === 0x801, "Invalid label file"); | ||
|
||
const count = imageView.getUint32(4); | ||
assert(count === labelView.getUint32(4), "Image and label count mismatch"); | ||
|
||
const results: Dataset[] = []; | ||
|
||
for (let i = 0; i < count; i++) { | ||
const inputs = new Float32Array(784); | ||
for (let j = 0; j < 784; j++) { | ||
inputs[j] = imageView.getUint8(16 + i * 784 + j) / 255; | ||
} | ||
|
||
const outputs = new Float32Array(10); | ||
outputs[labelView.getUint8(8 + i)] = 1; | ||
|
||
results.push({ | ||
inputs: new Matrix(1, inputs.length, inputs), | ||
outputs: new Matrix(1, outputs.length, outputs), | ||
}); | ||
} | ||
|
||
return results; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
async function download(url: string, to: string) { | ||
console.log("Download", url); | ||
const f = await Deno.open(new URL(to, import.meta.url), { | ||
write: true, | ||
create: true, | ||
}); | ||
await fetch(url).then((response) => { | ||
response.body!.pipeThrough(new DecompressionStream("gzip")).pipeTo( | ||
f.writable, | ||
); | ||
}); | ||
} | ||
|
||
await download( | ||
"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", | ||
"train-images.idx", | ||
); | ||
await download( | ||
"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", | ||
"train-labels.idx", | ||
); | ||
await download( | ||
"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", | ||
"test-images.idx", | ||
); | ||
await download( | ||
"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", | ||
"test-labels.idx", | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import { NativeBackend } from "../../src/native/backend.ts"; | ||
import { DataType, Matrix } from "../../backends/native.ts"; | ||
import { loadDataset } from "./common.ts"; | ||
|
||
const network = NativeBackend.load("digit_model.bin"); | ||
|
||
const testSet = loadDataset("test-images.idx", "test-labels.idx"); | ||
|
||
function argmax<T extends DataType>(mat: Matrix<T>) { | ||
let max = -Infinity; | ||
let index = -1; | ||
for (let i = 0; i < mat.data.length; i++) { | ||
if (mat.data[i] > max) { | ||
max = mat.data[i]; | ||
index = i; | ||
} | ||
} | ||
return index; | ||
} | ||
|
||
const correct = testSet.filter((e) => { | ||
const prediction = argmax(network.predict(e.inputs)); | ||
const expected = argmax(e.outputs); | ||
return prediction === expected; | ||
}); | ||
|
||
console.log(`${correct.length} / ${testSet.length} correct`); | ||
console.log(`accuracy: ${((correct.length / testSet.length) * 100).toFixed(2)}%`); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import { DenseLayer, NeuralNetwork } from "../../mod.ts"; | ||
import { Native } from "../../backends/native.ts"; | ||
import { loadDataset } from "./common.ts"; | ||
|
||
const network = await new NeuralNetwork({ | ||
input: 784, | ||
layers: [ | ||
new DenseLayer({ size: 28 * 2, activation: "sigmoid" }), | ||
new DenseLayer({ size: 10, activation: "sigmoid" }), | ||
], | ||
cost: "crossentropy", | ||
}).setupBackend(Native); | ||
|
||
console.log("Loading training dataset..."); | ||
const trainSet = loadDataset("train-images.idx", "train-labels.idx"); | ||
|
||
const epochs = 5; | ||
console.log("Training (" + epochs + " epochs)..."); | ||
const start = performance.now(); | ||
network.train(trainSet, epochs, 0.1); | ||
console.log("Training complete!", performance.now() - start); | ||
|
||
network.save("digit_model.bin"); |
Oops, something went wrong.