Skip to content

Commit

Permalink
limit the change (#67)
Browse files Browse the repository at this point in the history
Co-authored-by: Nigel@SudioMac <[email protected]>
  • Loading branch information
nleck and nigelleck authored Mar 11, 2023
1 parent af2a608 commit f7a8500
Show file tree
Hide file tree
Showing 13 changed files with 207 additions and 137 deletions.
1 change: 0 additions & 1 deletion src/architecture/ConnectionInterfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ export interface ConnectionTrace extends ConnectionExport {
trace: {
used?: boolean;
eligibility?: number;
previousDeltaWeight?: number;
totalDeltaWeight?: number;
};
}
46 changes: 26 additions & 20 deletions src/architecture/Network.ts
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ export class Network implements NetworkInternal {
/**
* Back propagate the network
*/
propagate(rate: number, momentum: number, update: boolean, target: number[]) {
propagate(rate: number, target: number[]) {
if (
target === undefined || target.length !== this.output
) {
Expand All @@ -921,8 +921,6 @@ export class Network implements NetworkInternal {
const n = this.nodes[i] as Node;
n.propagate(
rate,
momentum,
update,
target[--targetIndex],
);
}
Expand All @@ -934,7 +932,21 @@ export class Network implements NetworkInternal {
i--
) {
const n = this.nodes[i] as Node;
n.propagate(rate, momentum, update);
n.propagate(rate);
}
}

/**
* Back propagate the network
*/
propagateUpdate() {
for (
let indx = this.nodes.length - 1;
indx >= this.input;
indx--
) {
const n = this.nodes[indx] as Node;
n.propagateUpdate();
}
}

Expand Down Expand Up @@ -1188,10 +1200,7 @@ export class Network implements NetworkInternal {
const targetError = options.error || 0.05;
const cost = Costs.find(options.cost ? options.cost : "MSE");
const baseRate = options.rate == undefined ? Math.random() : options.rate;
const momentum = options.momentum == undefined
? Math.random()
: options.momentum;
const batchSize = options.batchSize || 1; // online learning

const ratePolicyName = options.ratePolicy
? options.ratePolicy
: randomPolicyName();
Expand Down Expand Up @@ -1240,7 +1249,10 @@ export class Network implements NetworkInternal {
throw "Set size must be positive";
}
const len = json.length;

const batchSize = Math.max(
options.batchSize ? options.batchSize : Math.round(len / 10),
1,
);
for (let i = len; i--;) {
const data = json[i];

Expand All @@ -1254,7 +1266,10 @@ export class Network implements NetworkInternal {

errorSum += cost.calculate(data.output, output);

this.propagate(currentRate, momentum, update, data.output);
this.propagate(currentRate, data.output);
if (update) {
this.propagateUpdate();
}
/* Clear if we've updated the state batch only */
if (update && (i || j)) {
/* Hold the last one so we can write it out */
Expand Down Expand Up @@ -1282,8 +1297,6 @@ export class Network implements NetworkInternal {
currentRate,
"policy",
yellow(ratePolicyName),
"momentum",
momentum,
);
}

Expand Down Expand Up @@ -2101,7 +2114,6 @@ export class Network implements NetworkInternal {
errorResponsibility: ns ? ns.errorResponsibility : undefined,
derivative: ns ? ns.derivative : undefined,
totalDeltaBias: ns ? ns.totalDeltaBias : undefined,
previousDeltaBias: ns ? ns.previousDeltaBias : undefined,
};
traceNodes[exportIndex] = traceNode as NodeTrace;
exportIndex++;
Expand All @@ -2115,7 +2127,6 @@ export class Network implements NetworkInternal {
exportConnection.trace = {
used: cs.used,
eligibility: cs.eligibility,
previousDeltaWeight: cs.previousDeltaWeight,
totalDeltaWeight: cs.totalDeltaWeight,
};

Expand Down Expand Up @@ -2198,9 +2209,6 @@ export class Network implements NetworkInternal {
: 0;
ns.derivative = trace.derivative ? trace.derivative : 0;
ns.totalDeltaBias = trace.totalDeltaBias ? trace.totalDeltaBias : 0;
ns.previousDeltaBias = trace.previousDeltaBias
? trace.previousDeltaBias
: 0;
}
uuidMap.set(n.uuid, pos);

Expand Down Expand Up @@ -2230,9 +2238,7 @@ export class Network implements NetworkInternal {
const trace = (conn as ConnectionTrace).trace;
cs.used = trace.used;
cs.eligibility = trace.eligibility ? trace.eligibility : 0;
cs.previousDeltaWeight = trace.previousDeltaWeight
? trace.previousDeltaWeight
: 0;

cs.totalDeltaWeight = trace.totalDeltaWeight
? trace.totalDeltaWeight
: 0;
Expand Down
2 changes: 2 additions & 0 deletions src/architecture/NetworkState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ class NodeState {
public derivative: number;
public totalDeltaBias: number;
public previousDeltaBias: number;
public batchSize: number;

constructor() {
this.errorResponsibility = 0;
this.errorProjected = 0;
this.derivative = 0;
this.totalDeltaBias = 0;
this.previousDeltaBias = 0;
this.batchSize = 0;
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/architecture/NetworkUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ export class NetworkUtil {
}
const json = JSON.parse(JSON.stringify(creature.internalJSON()));
json.nodes.forEach(
(n: { uuid?: string; index?: number }) => {
(n: { uuid?: string; trace?: unknown }) => {
delete n.uuid;
delete n.trace;
},
);
json.connections.forEach(
Expand Down
147 changes: 77 additions & 70 deletions src/architecture/Node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ export class Node implements TagsInterface, NodeInternal {
const squashMethod = this.findSquash();

if (this.isNodeActivation(squashMethod)) {
activation = squashMethod.activate(this) + this.bias;
const squashActivation = squashMethod.activate(this);
activation = squashActivation + this.bias;
} else {
const toList = this.network.toConnections(this.index);
let value = this.bias;
Expand Down Expand Up @@ -233,7 +234,6 @@ export class Node implements TagsInterface, NodeInternal {

for (let i = 0; i < toList.length; i++) {
const c = toList[i];
// const fromState = this.network.networkState.node(c.from);
const fromActivation = this.network.getActivation(c.from);
const cs = this.network.networkState.connection(c.from, c.to);
if (self) {
Expand Down Expand Up @@ -346,22 +346,74 @@ export class Node implements TagsInterface, NodeInternal {
return activation;
}

private limit(delta: number, limit: number) {
const limitedDelta = Math.min(
Math.max(delta, Math.abs(limit) * -1),
Math.abs(limit),
);

return limitedDelta;
}

propagateUpdate() {
const ns = this.network.networkState.node(this.index);
const toList = this.network.toConnections(this.index);
for (let i = toList.length; i--;) {
const c = toList[i];

const cs = this.network.networkState.connection(c.from, c.to);

c.weight += this.limit(cs.totalDeltaWeight, 0.1);
if (!Number.isFinite(c.weight)) {
if (c.weight === Number.POSITIVE_INFINITY) {
c.weight = Number.MAX_SAFE_INTEGER;
} else if (c.weight === Number.NEGATIVE_INFINITY) {
c.weight = Number.MIN_SAFE_INTEGER;
} else if (isNaN(c.weight)) {
c.weight = 0;
} else {
console.trace();
throw c.from + ":" + c.to + ") invalid weight: " + c.weight;
}
}

cs.previousDeltaWeight = this.limit(cs.totalDeltaWeight, 0.1);
cs.totalDeltaWeight = 0;
}

const deltaBias = ns.totalDeltaBias / ns.batchSize;

this.bias += deltaBias;
if (!Number.isFinite(this.bias)) {
if (this.bias === Number.POSITIVE_INFINITY) {
this.bias = Number.MAX_SAFE_INTEGER;
} else if (this.bias === Number.NEGATIVE_INFINITY) {
this.bias = Number.MIN_SAFE_INTEGER;
} else if (isNaN(this.bias)) {
this.bias = 0;
} else {
console.trace();
throw this.index + ") invalid this.bias: " + this.bias;
}
}

ns.totalDeltaBias = 0;
ns.batchSize = 0;
}
/**
* Back-propagate the error, aka learn
*/
propagate(rate: number, momentum: number, update: boolean, target?: number) {
// momentum = momentum || 0;
// rate = rate || 0.3;

propagate(rate: number, target?: number) {
// Error accumulator
let error = 0;

const ns = this.network.networkState.node(this.index);
// const sp = this.network.networkState.nodePersistent(this.index);

// Output nodes get their error from the environment
if (this.type === "output") {
const activation = this.network.getActivation(this.index);
ns.errorResponsibility = ns.errorProjected = (target ? target : 0) -
this.network.getActivation(this.index);
activation;
} else { // the rest of the nodes compute their error responsibilities by back propagation
// error responsibilities from all the connections projected from this node
const fromList = this.network.fromConnections(this.index);
Expand All @@ -370,16 +422,14 @@ export class Node implements TagsInterface, NodeInternal {
const c = fromList[i];

const toState = this.network.networkState.node(c.to);
// Eq. 21
// const cs = this.network.networkState.connection(c.from, c.to);

const tmpError = error +
toState.errorResponsibility * c.weight;
error = Number.isFinite(tmpError) ? tmpError : error;
}

// Projected error responsibility
ns.errorProjected = ns.derivative * error;

if (!Number.isFinite(ns.errorProjected)) {
if (ns.errorProjected === Number.POSITIVE_INFINITY) {
ns.errorProjected = Number.MAX_SAFE_INTEGER;
Expand All @@ -389,80 +439,37 @@ export class Node implements TagsInterface, NodeInternal {
ns.errorProjected = 0;
} else {
console.trace();
// console.info(state.error, this.derivative, error);

throw this.index + ") invalid error.projected: " + ns.errorProjected;
}
}

// Error responsibilities from all connections gated by this neuron
error = 0;

// Error responsibility
ns.errorResponsibility = ns.errorProjected;
}

if (this.type === "constant") {
return;
}
if (this.type !== "constant") {
// Adjust all the node's incoming connections
const toList = this.network.toConnections(this.index);
for (let i = toList.length; i--;) {
const c = toList[i];

// Adjust all the node's incoming connections
const toList = this.network.toConnections(this.index);
for (let i = toList.length; i--;) {
const c = toList[i];
const cs = this.network.networkState.connection(c.from, c.to);

const cs = this.network.networkState.connection(c.from, c.to);
// const csp = this.network.networkState.connectionPersistent(c.from, c.to);
const gradient = ns.errorProjected * cs.eligibility;

// Adjust weight
const deltaWeight = rate * gradient;

cs.totalDeltaWeight += deltaWeight;
if (update) {
cs.totalDeltaWeight += momentum *
cs.previousDeltaWeight;
c.weight += cs.totalDeltaWeight;
if (!Number.isFinite(c.weight)) {
if (c.weight === Number.POSITIVE_INFINITY) {
c.weight = Number.MAX_SAFE_INTEGER;
} else if (c.weight === Number.NEGATIVE_INFINITY) {
c.weight = Number.MIN_SAFE_INTEGER;
} else if (isNaN(c.weight)) {
c.weight = 0;
} else {
console.trace();
throw c.from + ":" + c.to + ") invalid weight: " + c.weight;
}
}
const gradient = ns.errorProjected * cs.eligibility;

cs.previousDeltaWeight = cs.totalDeltaWeight;
cs.totalDeltaWeight = 0;
}
}
// Adjust weight
const deltaWeight = rate * gradient;

// Adjust bias
const deltaBias = rate * ns.errorResponsibility;
ns.totalDeltaBias += deltaBias;
if (update) {
ns.totalDeltaBias += momentum * ns.previousDeltaBias;

this.bias += ns.totalDeltaBias;
if (!Number.isFinite(this.bias)) {
if (this.bias === Number.POSITIVE_INFINITY) {
this.bias = Number.MAX_SAFE_INTEGER;
} else if (this.bias === Number.NEGATIVE_INFINITY) {
this.bias = Number.MIN_SAFE_INTEGER;
} else if (isNaN(this.bias)) {
this.bias = 0;
} else {
console.trace();
throw this.index + ") invalid this.bias: " + this.bias;
}
cs.totalDeltaWeight += deltaWeight;
}

ns.previousDeltaBias = ns.totalDeltaBias;
ns.totalDeltaBias = 0;
// Adjust bias
const deltaBias = rate * ns.errorResponsibility;
ns.totalDeltaBias += deltaBias;
}

ns.batchSize++;
}

/**
Expand Down
1 change: 0 additions & 1 deletion src/architecture/NodeInterfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,5 @@ export interface NodeTrace extends NodeExport {
errorProjected?: number;
derivative?: number;
totalDeltaBias?: number;
previousDeltaBias?: number;
};
}
6 changes: 0 additions & 6 deletions src/config/TrainOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ export interface TrainOptions {
/** Sets the amount of iterations the process will maximally run, even when the target error has not been reached. Default: NaN */
iterations?: number;

// /** If set to true, will clear the network after every activation. This is useful for training LSTM's, more importantly for timeseries prediction. Default: false */
// clear?: boolean;

/** Sets the momentum of the weight change. More info here. Default: 0 */
momentum?: number;

/** Sets the rate policy for your training. This allows your rate to be dynamic, see the rate policies page. Default: methods.rate.FIXED() */
ratePolicy?: string;

Expand Down
Loading

0 comments on commit f7a8500

Please sign in to comment.