EN
JavaScript - Kohonen Neural Network (WTA Learning)
10
points
In this short article, we would like to show how to implement winner-takes-all (WTA) learning for a Kohonen neural network in JavaScript.
Practical example:
// ONLINE-RUNNER:browser;
// Common utils:
/**
 * Returns the number of rows of a two-dimensional array.
 *
 * @param {Array|null|undefined} array - row-major 2D array
 * @returns {number} row count; 0 when the array is null or undefined
 */
function detectArrayHeight(array) {
  return array == null ? 0 : array.length;
}
/**
 * Returns the number of columns of a two-dimensional array, measured on its first row.
 *
 * @param {Array|null|undefined} array - row-major 2D array
 * @returns {number} length of the first row; 0 when the array or its first row is missing
 */
function detectArrayWidth(array) {
  // Only the first row is inspected; ragged rows are not detected here.
  var firstRow = (array == null || array.length == 0) ? null : array[0];
  return firstRow == null ? 0 : firstRow.length;
}
// Distance neuron:
// -- utils
/**
 * Builds a DistanceNeuron whose weights are produced by a factory callback.
 *
 * @param {number} inputsCount - number of weights (passed to the Array constructor)
 * @param {Function} distance - distance function stored on the neuron
 * @param {Function} createWeight - called as createWeight(index) for each weight, in index order
 * @returns {DistanceNeuron} the new neuron
 */
function createDistanceNeuron(inputsCount, distance, createWeight) {
  var initialWeights = new Array(inputsCount);
  var index = 0;
  while (index < initialWeights.length) {
    initialWeights[index] = createWeight(index);
    index += 1;
  }
  return new DistanceNeuron(initialWeights, distance);
}
/**
 * Builds a neuron whose weights all start at 0.0.
 *
 * @param {number} inputsCount - number of weights
 * @param {Function} distance - distance function stored on the neuron
 */
function createZeroedNeuron(inputsCount, distance) {
  var zeroWeight = function() {
    return 0.0;
  };
  return createDistanceNeuron(inputsCount, distance, zeroWeight);
}
/**
 * Builds a neuron whose weights are drawn from Math.random(), i.e. in [0, 1).
 *
 * @param {number} inputsCount - number of weights
 * @param {Function} distance - distance function stored on the neuron
 */
function createRandomNeuron(inputsCount, distance) {
  var randomWeight = function() {
    return Math.random();
  };
  return createDistanceNeuron(inputsCount, distance, randomWeight);
}
// -- distances
// Source: https://dirask.com/posts/JavaScript-Chebyshev-Distance-function-p5q6Rp
//
/**
 * Chebyshev (maximum-coordinate) distance between two vectors.
 *
 * @param {number[]} a - first vector
 * @param {number[]} b - second vector
 * @returns {number} max |a[k] - b[k]|; NaN when a is empty or the lengths differ
 */
function calculateChebyshevDistance(a, b) {
  var size = a.length;
  if (size === 0 || size !== b.length) {
    return NaN;
  }
  var largest = Math.abs(a[0] - b[0]);
  for (var k = 1; k < size; k++) {
    var gap = Math.abs(a[k] - b[k]);
    // Plain comparison (not Math.max) so a NaN coordinate difference is skipped, as before.
    largest = gap > largest ? gap : largest;
  }
  return largest;
}
/**
 * Euclidean (L2) distance between two vectors.
 *
 * @param {number[]} a - first vector
 * @param {number[]} b - second vector
 * @returns {number} sqrt(sum((a[k] - b[k])^2)); NaN when a is empty or the lengths differ
 */
function calculateEuclideanDistance(a, b) {
  var size = a.length;
  if (size === 0 || size !== b.length) {
    return NaN;
  }
  var squares = 0.0;
  var k = 0;
  while (k < size) {
    var delta = a[k] - b[k];
    squares += delta * delta;
    k += 1;
  }
  return Math.sqrt(squares);
}
// Source: https://dirask.com/posts/JavaScript-Rectilinear-Distance-function-1XgJgj
//
/**
 * Rectilinear (Manhattan, L1) distance between two vectors.
 *
 * @param {number[]} a - first vector
 * @param {number[]} b - second vector
 * @returns {number} sum(|a[k] - b[k]|); NaN when a is empty or the lengths differ
 */
function calculateRectilinearDistance(a, b) {
  var size = a.length;
  if (size === 0 || size !== b.length) {
    return NaN;
  }
  var total = 0.0;
  var k = 0;
  while (k < size) {
    total += Math.abs(a[k] - b[k]);
    k += 1;
  }
  return total;
}
// -- neuron
/**
 * Neuron whose output is the distance between its weight vector and the inputs.
 *
 * @constructor
 * @param {number[]} weights - weight vector; kept by reference and mutated in place by randomize$*
 * @param {Function} distance - called as distance(weights, inputs) to produce the output
 */
function DistanceNeuron(weights, distance) {
  // Accessors: the weight array is returned by reference, so callers may update it directly.
  this.getWeights = function() {
    return weights;
  };
  this.getDistance = function() {
    return distance;
  };
  // Randomizes the weights in [0.0, 1.0) via the two-argument overload.
  this.randomize$1 = function() {
    this.randomize$2(0.0, 1.0);
  };
  // Overwrites each weight with min + (max - min) * Math.random().
  this.randomize$2 = function(min, max) {
    var span = max - min;
    for (var k = 0; k < weights.length; k++) {
      weights[k] = min + span * Math.random();
    }
  };
  // Output = distance(weights, inputs); invoked with a null receiver.
  this.compute = function(inputs) {
    return distance.apply(null, [weights, inputs]);
  };
}
// Distance network:
// -- utils
/**
 * Builds a height x width grid of neurons and wraps it in a DistanceNetwork.
 *
 * @param {number} inputsCount - number of inputs per neuron
 * @param {number} width - neurons per row
 * @param {number} height - number of rows
 * @param {Function} createNeuron - called as createNeuron(rowIndex, columnIndex), row by row
 * @returns {DistanceNetwork} network over the grid indexed as grid[y][x]
 */
function createDistanceNetwork(inputsCount, width, height, createNeuron) {
  var grid = new Array(height);
  for (var y = 0; y < grid.length; y++) {
    var row = new Array(width);
    for (var x = 0; x < row.length; x++) {
      row[x] = createNeuron(y, x);
    }
    grid[y] = row;
  }
  return new DistanceNetwork(inputsCount, grid);
}
/**
 * Builds a width x height network whose neurons all start with zero weights.
 *
 * @param {number} inputsCount - number of inputs (and weights) per neuron
 * @param {number} width - neurons per row
 * @param {number} height - number of rows
 * @param {Function} distance - distance function shared by every neuron
 */
function createZeroedNetwork(inputsCount, width, height, distance) {
  var makeNeuron = function() {
    return createZeroedNeuron(inputsCount, distance);
  };
  return createDistanceNetwork(inputsCount, width, height, makeNeuron);
}
/**
 * Builds a width x height network whose neurons start with random weights.
 *
 * @param {number} inputsCount - number of inputs (and weights) per neuron
 * @param {number} width - neurons per row
 * @param {number} height - number of rows
 * @param {Function} distance - distance function shared by every neuron
 */
function createRandomNetwork(inputsCount, width, height, distance) {
  var makeNeuron = function() {
    return createRandomNeuron(inputsCount, distance);
  };
  return createDistanceNetwork(inputsCount, width, height, makeNeuron);
}
// -- network
/**
 * Two-dimensional grid of distance neurons, indexed as neurons[y][x].
 *
 * Width is taken from the first row and height from the number of rows;
 * both are fixed at construction time.
 *
 * @constructor
 * @param {number} inputsCount - number of inputs each neuron expects (stored, not validated)
 * @param {Array<Array<Object>>} neurons - grid of neurons exposing compute / randomize$1 / randomize$2
 */
function DistanceNetwork(inputsCount, neurons) {
  var width = detectArrayWidth(neurons);
  var height = detectArrayHeight(neurons);
  var outputsCount = width * height;

  this.getInputsCount = function() {
    return inputsCount;
  };
  this.getOutputsCount = function() {
    return outputsCount;
  };
  this.getWidth = function() {
    return width;
  };
  this.getHeight = function() {
    return height;
  };
  this.getNeurons = function() {
    return neurons;
  };

  // Returns the neuron at position {x, y}.
  this.getNeuron$1 = function(position) {
    return this.getNeuron$2(position.x, position.y);
  };
  // Returns the neuron at (positionX, positionY); null when the row does not exist
  // (an existing row with an out-of-range column yields undefined).
  this.getNeuron$2 = function(positionX, positionY) {
    var row = neurons[positionY];
    if (row) {
      return row[positionX];
    }
    return null;
  };

  // Randomizes every neuron using each neuron's own default randomize$1().
  this.randomize$1 = function() {
    for (var y = 0; y < height; ++y) {
      var row = neurons[y];
      for (var x = 0; x < width; ++x) {
        // Neurons expose randomize$1; the previous call to `random$1` always threw a TypeError.
        row[x].randomize$1();
      }
    }
  };
  // Randomizes every neuron's weights in the [min, max) range.
  this.randomize$2 = function(min, max) {
    for (var y = 0; y < height; ++y) {
      var row = neurons[y];
      for (var x = 0; x < width; ++x) {
        row[x].randomize$2(min, max);
      }
    }
  };

  // Computes all neuron outputs, flattened row by row (index = y * width + x).
  this.compute$1 = function(inputs) {
    var index = -1;
    var outputs = new Array(outputsCount);
    for (var y = 0; y < height; ++y) {
      var row = neurons[y];
      for (var x = 0; x < width; ++x) {
        outputs[++index] = row[x].compute(inputs);
      }
    }
    return outputs;
  };
  // Computes neuron outputs for each data row.
  this.compute$2 = function(data) {
    var result = Array(data.length);
    for (var i = 0; i < data.length; ++i) {
      result[i] = this.compute$1(data[i]);
    }
    return result;
  };

  // Searches for the winning neuron and returns its position and output.
  // The winner is the neuron with the smallest output (closest weights to the inputs);
  // ties go to the first neuron in row-major order. Returns null when no output is
  // smaller than +Infinity (empty grid, or every output is NaN/+Infinity).
  this.search$1 = function(inputs) {
    var positionX = -1;
    var positionY = -1;
    var value = +Infinity;
    for (var y = 0; y < height; ++y) {
      var row = neurons[y];
      for (var x = 0; x < width; ++x) {
        var output = row[x].compute(inputs);
        if (output < value) {
          positionX = x;
          positionY = y;
          value = output;
        }
      }
    }
    if (positionX === -1) {
      return null;
    }
    return {
      x: positionX,
      y: positionY,
      value: value
    };
  };
  // Searches for the winning neuron of each data row and returns their positions.
  this.search$2 = function(data) {
    var result = Array(data.length);
    for (var i = 0; i < data.length; ++i) {
      result[i] = this.search$1(data[i]);
    }
    return result;
  };
}
// WTA learning (winner-takes-all)
// -- utils
/**
 * Runs training epochs until the epoch limit is reached or the epoch error
 * drops to (or below) the tolerated error.
 *
 * @param {Object} learning - object exposing runEpoch(data) that returns the epoch error
 * @param {Array} data - training rows passed unchanged to runEpoch
 * @param {number} epochsLimit - maximum number of epochs to run
 * @param {number} toleratedError - training stops once the error is not greater than this
 * @param {*} rateModel - NOTE(review): accepted but not used by this function
 * @param {*} radiusModel - NOTE(review): accepted but not used by this function
 * @returns {number} error of the last epoch run; +Infinity when no epoch ran
 */
function runWTATraining(learning, data, epochsLimit, toleratedError, rateModel, radiusModel) {
  var lastError = +Infinity;
  var epoch = 0;
  while (epoch < epochsLimit && lastError > toleratedError) {
    lastError = learning.runEpoch(data);
    epoch += 1;
  }
  return lastError;
}
// -- learning
/**
 * Winner-takes-all learning: for each input row only the winning neuron
 * (as chosen by network.search$1) moves its weights toward the inputs.
 *
 * @constructor
 * @param {Object} network - network exposing search$1(inputs) and getNeurons()
 */
function WTALearning(network) {
  var rate = 0.01; // learning rate (default)

  this.getRate = function() {
    return rate;
  };
  this.setRate = function(value) {
    rate = value;
  };

  // Moves each weight toward its input by `rate * (input - weight)` and returns the
  // sum of |input - weight| measured before the move.
  var pullTowardInputs = function(neuron, inputs) {
    var weights = neuron.getWeights();
    var totalDeviation = 0.0;
    for (var k = 0; k < weights.length; k++) {
      var delta = inputs[k] - weights[k];
      weights[k] = weights[k] + rate * delta;
      totalDeviation += Math.abs(delta);
    }
    return totalDeviation;
  };

  // Trains on a single input row; returns +Infinity when no winner is found.
  this.run = function(inputs) {
    var winner = network.search$1(inputs);
    if (winner == null) {
      return +Infinity;
    }
    var winnerRow = network.getNeurons()[winner.y];
    return pullTowardInputs(winnerRow[winner.x], inputs);
  };

  // Trains on every row once and returns the summed per-row error.
  this.runEpoch = function(data) {
    var epochError = 0.0;
    var k = 0;
    while (k < data.length) {
      epochError += this.run(data[k]);
      k += 1;
    }
    return epochError;
  };
}
// Usage example:
// -- configuration
var inputsCount = 2;
var networkWidth = 2;
var networkHeight = 2;
var epochsLimit = 200; // learning epochs limit
var toleratedError = 0.01; // tolerated learning error
// -- data
var data = [ // as we can see, the input data can easily be split into 2 areas by row location
  [1, 2], // around area 1
  [1, 3], // around area 1
  [4, 5], // around area 2
  [5, 5]  // around area 2
];
// -- training
// Available distance functions:
//
// - calculateChebyshevDistance
// - calculateRectilinearDistance
// - calculateEuclideanDistance
//
var network = createRandomNetwork(inputsCount, networkWidth, networkHeight, calculateEuclideanDistance);
var learning = new WTALearning(network);
var error = runWTATraining(learning, data, epochsLimit, toleratedError);
// -- results
console.log('learning error: ' + error);
// Winning neuron for each data row (neuron coordinates: x and y, neuron output: value).
// Ideally rows 0-1 (around area 1) and rows 2-3 (around area 2) map to neurons of their own group;
// with random initial weights this is not guaranteed on every run.
data.forEach(function(row) {
  console.log(JSON.stringify(network.search$1(row)));
});
Hint: it is good practice to normalize the input data before training.