Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast get synapse #326

Merged
merged 6 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion quality.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -e

deno fmt src test bench mod.ts
deno lint src test bench mod.ts
rm -rf .trace .test
rm -rf .trace .test .coverage
deno test \
--allow-read \
--allow-write \
Expand Down
141 changes: 56 additions & 85 deletions src/Creature.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,16 @@ export class Creature implements CreatureInternal {
this.neurons.length = 0;
}

public clearCache() {
this.cacheTo.clear();
this.cacheFrom.clear();
this.cacheSelf.clear();
public clearCache(from = -1, to = -1) {
if (from == -1 || to == -1) {
this.cacheTo.clear();
this.cacheFrom.clear();
this.cacheSelf.clear();
} else {
this.cacheTo.delete(to);
this.cacheFrom.delete(from);
this.cacheSelf.delete(from);
}
}

private initialize(options: {
Expand Down Expand Up @@ -730,9 +736,9 @@ export class Creature implements CreatureInternal {
let results = this.cacheTo.get(toIndx);
if (results === undefined) {
results = [];
const tmpList = this.synapses;
for (let i = tmpList.length; i--;) {
const c = tmpList[i];

for (let i = this.synapses.length; i--;) {
const c = this.synapses[i];

if (c.to === toIndx) results.push(c);
}
Expand All @@ -755,38 +761,57 @@ export class Creature implements CreatureInternal {
outwardConnections(fromIndx: number): Synapse[] {
let results = this.cacheFrom.get(fromIndx);
if (results === undefined) {
results = [];
const tmpList = this.synapses;
for (let i = tmpList.length; i--;) {
const c = tmpList[i];

if (c.from === fromIndx) results.push(c);
const startIndex = this.binarySearchForStartIndex(fromIndx);

if (startIndex !== -1) {
results = [];
for (let i = startIndex; i < this.synapses.length; i++) {
const tmp = this.synapses[i];
if (tmp.from === fromIndx) {
results.push(tmp);
} else {
break; // Since it's sorted, no need to continue once 'from' changes
}
}
} else {
results = []; // No connections found
}

this.cacheFrom.set(fromIndx, results);
}
return results;
}

getSynapse(from: number, to: number): Synapse | null {
if (Number.isInteger(from) == false || from < 0) {
throw new Error("FROM should be a non-negative integer was: " + from);
}
private binarySearchForStartIndex(fromIndx: number): number {
let low = 0;
let high = this.synapses.length - 1;
let result = -1; // Default to -1 if not found

if (Number.isInteger(to) == false || to < 0) {
throw new Error("TO should be a non-negative integer was: " + to);
while (low <= high) {
const mid = Math.floor((low + high) / 2);
const midValue = this.synapses[mid];

if (midValue.from < fromIndx) {
low = mid + 1;
} else if (midValue.from > fromIndx) {
high = mid - 1;
} else {
result = mid; // Found a matching 'from', but need the first occurrence
high = mid - 1; // Look left to find the first match
}
}

for (let pos = this.synapses.length; pos--;) {
const c = this.synapses[pos];
return result;
}

if (c.from == from) {
if (c.to == to) {
return c;
} else if (c.to < to) {
break;
}
} else if (c.from < from) {
getSynapse(from: number, to: number): Synapse | null {
const outwardConnections = this.outwardConnections(from);

for (let indx = outwardConnections.length; indx--;) {
const c = outwardConnections[indx];
if (c.to == to) {
return c;
} else if (c.to < to) {
break;
}
}
Expand All @@ -803,50 +828,6 @@ export class Creature implements CreatureInternal {
weight: number,
type?: "positive" | "negative" | "condition",
): Synapse {
if (Number.isInteger(from) == false || from < 0) {
throw new Error("from should be a non-negative integer was: " + from);
}

if (Number.isInteger(to) == false || to < 0) {
throw new Error("to should be a non-negative integer was: " + to);
}

if (to < this.input) {
throw new Error(
"to should not be pointed to any input neurons(" +
this.input + "): " + to,
);
}

if (to < from) {
throw new Error("to: " + to + " should not be less than from: " + from);
}

if (typeof weight !== "number") {
if (this.DEBUG) {
this.DEBUG = false;
console.warn(
JSON.stringify(this.exportJSON(), null, 2),
);

this.DEBUG = true;
}

throw new Error(from + ":" + to + ") weight not a number was: " + weight);
}

const toNeuron = this.neurons[to];
if (toNeuron) {
const toType = toNeuron.type;
if (toType == "constant" || toType == "input") {
throw new Error(`Can not connect ${from}->${to} with type ${toType}`);
}
} else {
throw new Error(
`Can't connect to index: ${to} of length: ${this.neurons.length}`,
);
}

const connection = new Synapse(
from,
to,
Expand Down Expand Up @@ -887,7 +868,7 @@ export class Creature implements CreatureInternal {
this.synapses.push(connection);
}

this.clearCache();
this.clearCache(from, to);

return connection;
}
Expand All @@ -896,14 +877,6 @@ export class Creature implements CreatureInternal {
* Disconnects the from neuron from the to node
*/
disconnect(from: number, to: number) {
if (Number.isInteger(from) == false || from < 0) {
throw new Error("from should be a non-negative integer was: " + from);
}
if (Number.isInteger(to) == false || to < 0) {
throw new Error("to should be a non-negative integer was: " + to);
}

// Delete the connection in the creature's connection array
const connections = this.synapses;

let found = false;
Expand All @@ -912,15 +885,13 @@ export class Creature implements CreatureInternal {
if (connection.from === from && connection.to === to) {
found = true;
connections.splice(i, 1);
this.clearCache();
this.clearCache(from, to);

break;
}
}

if (!found) {
throw new Error("No connection from: " + from + ", to: " + to);
}
assert(found, "Can't disconnect");
}

async applyLearnings(config: BackPropagationConfig) {
Expand Down
8 changes: 0 additions & 8 deletions src/architecture/Neuron.ts
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,6 @@ export class Neuron implements TagsInterface, NeuronInternal {
return c != null;
}

/**
* Checks if the given node is projecting to this node
*/
isProjectedBy(node: Neuron) {
const c = this.creature.getSynapse(node.index, this.index);
return c != null;
}

/**
* Converts the node to a json object
*/
Expand Down
25 changes: 13 additions & 12 deletions test/Constants/Constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,31 @@ Deno.test("Constants", () => {
input: 1,
output: 1,
};
const network = Creature.fromJSON(json);
network.validate();
const creature = Creature.fromJSON(json);
creature.validate();

for (let i = 100; i--;) {
network.modBias();
network.addConnection();
creature.modBias();
creature.addConnection();
}

network.validate();
Creature.fromJSON(network.exportJSON());
creature.validate();
Creature.fromJSON(creature.exportJSON());
assert(
Math.abs(network.neurons[1].bias) - 0.5 <
Math.abs(creature.neurons[1].bias) - 0.5 <
0.00001,
"Should NOT have changed the constant node was: " + network.neurons[1].bias,
"Should NOT have changed the constant node was: " +
creature.neurons[1].bias,
);

assert(
(network.neurons[2].bias) > 0.60001 ||
(network.neurons[2].bias) < 0.59999,
"Should have changed the hidden node was: " + network.neurons[2].bias,
(creature.neurons[2].bias) > 0.60001 ||
(creature.neurons[2].bias) < 0.59999,
"Should have changed the hidden node was: " + creature.neurons[2].bias,
);

assert(
network.inwardConnections(1).length === 0,
creature.inwardConnections(1).length === 0,
"Should not have any inward connections",
);
});
4 changes: 0 additions & 4 deletions test/Projection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,4 @@ Deno.test("projection", () => {
const flag3to0 = (outNode as Neuron).isProjectingTo(inNode0 as Neuron);

assert(!flag3to0, "3 -> 0 should not be associated");

const project3by0 = (outNode as Neuron).isProjectedBy(inNode0 as Neuron);

assert(project3by0, "3 is projected by 0");
});