Skip to content

Commit

Permalink
#56v2 (#57)
Browse files Browse the repository at this point in the history
* Working on trace

* Working on trace

---------

Co-authored-by: Nigel@SudioMac <[email protected]>
  • Loading branch information
nleck and nigelleck authored Mar 5, 2023
1 parent aa2bf47 commit 43b9d24
Show file tree
Hide file tree
Showing 44 changed files with 228 additions and 78 deletions.
23 changes: 19 additions & 4 deletions src/Neat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { addTag, getTag, removeTag } from "../src/tags/TagsInterface.ts";
import { fineTuneImprovement } from "./architecture/FineTune.ts";
import { makeElitists } from "../src/architecture/elitism.ts";
import { Network } from "./architecture/Network.ts";
import { ensureDirSync } from "https://deno.land/std@0.170.0/fs/ensure_dir.ts";
import { ensureDirSync } from "https://deno.land/std@0.177.0/fs/ensure_dir.ts";
import { Mutation } from "./methods/mutation.ts";
import { Selection } from "./methods/Selection.ts";
import { Offspring } from "./architecture/Offspring.ts";
Expand Down Expand Up @@ -63,13 +63,14 @@ export class Neat {
const trainPromises = [];
for (
let i = 0;
i < this.population.length && i < this.workers.length;
i < this.population.length && i < Math.max(1, this.workers.length / 2);
i++
) {
const n = this.population[i];
if (n.score) {
const trained = getTag(n, "trained");
if (trained !== "YES") {
if (trained !== "YES" && i == 0) {
// console.info( `train ${n.uuid}`);
const p = this.workers[i].train(n, this.trainRate);
trainPromises.push(p);
addTag(n, "trained", "YES");
Expand Down Expand Up @@ -255,7 +256,7 @@ export class Neat {

const trainPopulation: Network[] = [];

await Promise.all(trainPromises).then((results) => {
await Promise.all(trainPromises).then(async (results) => {
for (let i = results.length; i--;) {
const r = results[i];
if (r.train) {
Expand All @@ -267,6 +268,20 @@ export class Neat {
// addTag(json, "duration", r.duration);

trainPopulation.push(Network.fromJSON(json, this.config.debug));
if (this.config.trainStore) {
if (r.train.trace) {
// Deno.writeTextFileSync( ".hack.json", JSON.stringify( JSON.parse( r.train.trace), null, 2));
const traceNetwork = Network.fromJSON(
JSON.parse(r.train.trace),
);
await NetworkUtil.makeUUID(traceNetwork);

Deno.writeTextFileSync(
`${this.config.trainStore}/${traceNetwork.uuid}.json`,
JSON.stringify(traceNetwork.traceJSON(), null, 2),
);
}
}
}
} else {
throw "No train result";
Expand Down
7 changes: 7 additions & 0 deletions src/architecture/ConnectionInterfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,10 @@ export interface ConnectionExport extends ConnectionCommon {
toUUID: string;
gaterUUID?: string;
}

export interface ConnectionTrace extends ConnectionExport {
trace: {
used: boolean;
// eligibility: number
};
}
86 changes: 63 additions & 23 deletions src/architecture/Network.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@ import { TagInterface } from "../tags/TagInterface.ts";
import {
ConnectionExport,
ConnectionInternal,
ConnectionTrace,
} from "./ConnectionInterfaces.ts";
import { NodeExport, NodeInternal } from "./NodeInterfaces.ts";
import { NetworkExport, NetworkInternal } from "./NetworkInterfaces.ts";
import {
NetworkExport,
NetworkInternal,
NetworkTrace,
} from "./NetworkInterfaces.ts";

import { DataRecordInterface } from "./DataSet.ts";
import { make as makeConfig } from "../config/NeatConfig.ts";
import { NeatOptions } from "../config/NeatOptions.ts";

import { yellow } from "https://deno.land/std@0.170.0/fmt/colors.ts";
import { yellow } from "https://deno.land/std@0.177.0/fmt/colors.ts";
import { WorkerHandler } from "../multithreading/workers/WorkerHandler.ts";
import { Neat } from "../Neat.ts";
import { getTag } from "../tags/TagsInterface.ts";
import { makeDataDir } from "../architecture/DataSet.ts";

import { TrainOptions } from "../config/TrainOptions.ts";
import { findRatePolicy, randomPolicyName } from "../config.ts";
import { emptyDirSync } from "https://deno.land/std@0.170.0/fs/empty_dir.ts";
import { emptyDirSync } from "https://deno.land/std@0.177.0/fs/empty_dir.ts";
import { Mutation } from "../methods/mutation.ts";
import { Node } from "../architecture/Node.ts";
import { Connection } from "./Connection.ts";
Expand Down Expand Up @@ -74,10 +79,10 @@ export class Network implements NetworkInternal {

/* Dispose of the network and all held memory */
public dispose() {
this.clear();
this.clearState();
this.clearCache();
this.connections = [];
this.nodes = [];
this.connections.length = 0;
this.nodes.length = 0;
}

public clearCache() {
Expand Down Expand Up @@ -195,7 +200,7 @@ export class Network implements NetworkInternal {
/**
* Clear the context of the network
*/
clear() {
clearState() {
this.networkState.clear();
}

Expand Down Expand Up @@ -998,7 +1003,6 @@ export class Network implements NetworkInternal {
options: NeatOptions,
) {
const config = makeConfig(options);
// Read the options

const start = Date.now();

Expand Down Expand Up @@ -1260,11 +1264,7 @@ export class Network implements NetworkInternal {
let iteration = 0;
let error = 1;
const EMPTY = { input: [], output: [] };
while (
Number.isFinite(error) &&
error > targetError &&
(iterations === 0 || iteration < iterations)
) {
while (true) {
iteration++;

// Update the rate
Expand Down Expand Up @@ -1305,18 +1305,23 @@ export class Network implements NetworkInternal {
/* Not cached so we can release memory as we go */
json[i] = EMPTY;
}
const update = (i + 1) % batchSize === 0 || i === 0;
const update = (i + 1) % batchSize === 0 || (i === 0 && j == 0);

const output = this.activate(data.input);

errorSum += cost.calculate(data.output, output);

this.propagate(currentRate, momentum, update, data.output);
/* Clear if we've updated the state batch only */
if (update && (i || j)) {
/* Hold the last one so we can write it out */
this.clearState();
}
}

counter += len;
}
this.applyLearnings();

error = errorSum / counter;

if (
Expand All @@ -1332,21 +1337,33 @@ export class Network implements NetworkInternal {
error,
"rate",
currentRate,
"clear",
options.clear ? true : false,
// "clear",
// options.clear ? true : false,
"policy",
yellow(ratePolicyName),
"momentum",
momentum,
);
}
}

if (options.clear) this.clear();
if (
Number.isFinite(error) &&
error > targetError &&
(iterations === 0 || iteration < iterations)
) {
this.applyLearnings();
this.clearState();
} else {
const traceJSON = this.traceJSON();
this.applyLearnings();
this.clearState();

return {
error: error,
};
return {
error: error,
trace: traceJSON,
};
}
}
}

/**
Expand Down Expand Up @@ -2212,6 +2229,24 @@ export class Network implements NetworkInternal {
return json;
}

traceJSON(): NetworkTrace {
const json = this.exportJSON();

const traceConnections = Array<ConnectionTrace>(json.connections.length);
this.connections.forEach((c, indx) => {
const exportConnection = json.connections[indx] as ConnectionTrace;
const cs = this.networkState.connection(c.from, c.to);
exportConnection.trace = {
used: cs.xTrace.used,
};

traceConnections[indx] = exportConnection;
});
json.connections = traceConnections;

return json as NetworkTrace;
}

internalJSON() {
if (this.DEBUG) {
this.validate();
Expand Down Expand Up @@ -2254,6 +2289,8 @@ export class Network implements NetworkInternal {
this.tags = [...json.tags];
}

this.clearState();

const uuidMap = new Map<string, number>();
this.nodes = new Array(json.nodes.length);
for (let i = json.input; i--;) {
Expand Down Expand Up @@ -2294,6 +2331,10 @@ export class Network implements NetworkInternal {
conn.weight,
conn.type,
);
if ((conn as ConnectionTrace).trace) {
const cs = this.networkState.connection(connection.from, connection.to);
cs.xTrace.used = (conn as ConnectionTrace).trace.used;
}

const gater = (conn as ConnectionInternal).gater;
if (Number.isFinite(gater)) {
Expand All @@ -2307,7 +2348,6 @@ export class Network implements NetworkInternal {
}

this.clearCache();
this.clear();

if (validate) {
this.validate();
Expand Down
7 changes: 7 additions & 0 deletions src/architecture/NetworkInterfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { TagsInterface } from "../tags/TagsInterface.ts";
import {
ConnectionExport,
ConnectionInternal,
ConnectionTrace,
} from "./ConnectionInterfaces.ts";
import { NodeExport, NodeInternal } from "./NodeInterfaces.ts";

Expand All @@ -26,3 +27,9 @@ export interface NetworkExport extends NetworkCommon {

nodes: NodeExport[];
}

export interface NetworkTrace extends NetworkExport {
connections: ConnectionTrace[];

nodes: NodeExport[];
}
7 changes: 6 additions & 1 deletion src/architecture/NetworkUtils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { generate as generateV5 } from "https://deno.land/std@0.170.0/uuid/v5.ts";
import { generate as generateV5 } from "https://deno.land/std@0.177.0/uuid/v5.ts";
import { Network } from "./Network.ts";

export class NetworkUtil {
Expand All @@ -20,6 +20,11 @@ export class NetworkUtil {
delete n.uuid;
},
);
json.connections.forEach(
(c: { trace?: { used: boolean }; index?: number }) => {
delete c.trace;
},
);
delete json.tags;
delete json.uuid;
delete json.score;
Expand Down
2 changes: 1 addition & 1 deletion src/architecture/Offspring.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ export class Offspring {
});
}
}
offspring.clear();
offspring.clearState();

connectionList.forEach((c) => {
if (offspring.getConnection(c.from, c.to) == null) {
Expand Down
15 changes: 3 additions & 12 deletions src/config/NeatConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@ import { NeatOptions } from "./NeatOptions.ts";
import { Mutation, MutationInterface } from "../methods/mutation.ts";
import { Selection, SelectionInterface } from "../methods/Selection.ts";

export interface NeatConfig {
clear: boolean;
/** The directory to store the creatures (optional) */
creatureStore?: string;

/** The directory to store the experiments (optional) */
experimentStore?: string;

export interface NeatConfig extends NeatOptions {
/** List of creatures to start with */
creatures: NetworkInternal[] | NetworkExport[];

Expand Down Expand Up @@ -44,9 +37,6 @@ export interface NeatConfig {

growth: number;

/** Once the number of minutes are reached exit the loop. */
timeoutMinutes?: number;

/** Tne maximum number of connections */
maxConns: number;

Expand Down Expand Up @@ -81,7 +71,7 @@ export function make(parameters?: NeatOptions) {
const options = parameters || {};

const config: NeatConfig = {
clear: options.clear || false,
// clear: options.clear || false,

creatureStore: options.creatureStore,
experimentStore: options.experimentStore,
Expand Down Expand Up @@ -129,6 +119,7 @@ export function make(parameters?: NeatOptions) {
),
),
timeoutMinutes: options.timeoutMinutes,
trainStore: options.trainStore,
trainRate: options.trainRate ? options.trainRate : 0.01,

log: options.log ? options.log : 0,
Expand Down
5 changes: 4 additions & 1 deletion src/config/NeatOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export interface NeatOptions {
/** Target error 0 to 1 */
error?: number;

clear?: boolean;
// clear?: boolean;

costName?: string;

Expand Down Expand Up @@ -50,6 +50,9 @@ export interface NeatOptions {
/** Once the number of minutes are reached exit the loop. */
timeoutMinutes?: number;

/** The directory to store the trained networks (optional) */
trainStore?: string;

/** Tne maximum number of connections */
maxConns?: number;

Expand Down
4 changes: 2 additions & 2 deletions src/config/TrainOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ export interface TrainOptions {
/** Sets the amount of iterations the process will maximally run, even when the target error has not been reached. Default: NaN */
iterations?: number;

/** If set to true, will clear the network after every activation. This is useful for training LSTM's, more importantly for timeseries prediction. Default: false */
clear?: boolean;
// /** If set to true, will clear the network after every activation. This is useful for training LSTM's, more importantly for timeseries prediction. Default: false */
// clear?: boolean;

/** Sets the momentum of the weight change. More info here. Default: 0 */
momentum?: number;
Expand Down
1 change: 1 addition & 0 deletions src/multithreading/workers/WorkerHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export interface ResponseData {
train?: {
network: string;
error: number;
trace: string;
};
echo?: {
message: string;
Expand Down
1 change: 1 addition & 0 deletions src/multithreading/workers/WorkerProcessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export class WorkerProcessor {
train: {
network: json,
error: result.error,
trace: JSON.stringify(result.trace),
},
};
} else if (data.echo) {
Expand Down
Loading

0 comments on commit 43b9d24

Please sign in to comment.