Languages
[Edit]
EN

JavaScript - Kohonen Neural Network (WTA Learning)

10 points
Created by:
DoLLot
614

In this short article, we show how to implement winner-takes-all (WTA) learning for a Kohonen neural network in JavaScript.

Practical example:

// ONLINE-RUNNER:browser;

// Common utils:

// Returns the number of rows in a 2D array (its length), or 0 when the array is null/undefined.
//
function detectArrayHeight(array) {
    return array == null ? 0 : array.length;
}
	
// Returns the number of columns in a 2D array, measured on its first row only.
// Yields 0 when the array is null/undefined, empty, or its first row is null/undefined.
//
function detectArrayWidth(array) {
    var firstRow = (array == null || array.length == 0) ? null : array[0];
    return firstRow == null ? 0 : firstRow.length;
}


// Distance neuron:

// -- utils

// Builds a DistanceNeuron with `inputsCount` weights.
// createWeight(index) supplies the initial value of each weight, in index order.
//
function createDistanceNeuron(inputsCount, distance, createWeight) {
    var weights = Array(inputsCount);
    var index = 0;
    while (index < weights.length) {
        weights[index] = createWeight(index);
        index += 1;
    }
    return new DistanceNeuron(weights, distance);
}

// Builds a DistanceNeuron whose weights all start at 0.0.
//
function createZeroedNeuron(inputsCount, distance) {
    function zeroWeight() {
        return 0.0;
    }
    return createDistanceNeuron(inputsCount, distance, zeroWeight);
}

// Builds a DistanceNeuron whose weights start as Math.random() values, i.e. in [0, 1).
//
function createRandomNeuron(inputsCount, distance) {
    function randomWeight() {
        return Math.random();
    }
    return createDistanceNeuron(inputsCount, distance, randomWeight);
}

// -- distances

// Source: https://dirask.com/posts/JavaScript-Chebyshev-Distance-function-p5q6Rp
//
function calculateChebyshevDistance(a, b) {
    // Chebyshev distance: the largest absolute coordinate difference.
    // Returns NaN for empty vectors or vectors of different lengths.
    var size = a.length;
    if (size === 0 || size !== b.length) {
        return NaN;
    }
    var largest = Math.abs(a[0] - b[0]);
    for (var k = 1; k < size; k += 1) {
        var delta = Math.abs(a[k] - b[k]);
        // Strict `>` keeps the current maximum when `delta` is NaN.
        largest = delta > largest ? delta : largest;
    }
    return largest;
}

// Euclidean distance: square root of the sum of squared coordinate differences.
// Returns NaN for empty vectors or vectors of different lengths.
//
function calculateEuclideanDistance(a, b) {
    var size = a.length;
    if (size === 0 || size !== b.length) {
        return NaN;
    }
    var squares = 0.0;
    var k = 0;
    while (k < size) {
        var delta = a[k] - b[k];
        squares += delta * delta;
        k += 1;
    }
    return Math.sqrt(squares);
}

// Source: https://dirask.com/posts/JavaScript-Rectilinear-Distance-function-1XgJgj
//
function calculateRectilinearDistance(a, b) {
    // Rectilinear (Manhattan) distance: the sum of absolute coordinate differences.
    // Returns NaN for empty vectors or vectors of different lengths.
    var size = a.length;
    if (size === 0 || size !== b.length) {
        return NaN;
    }
    var total = 0.0;
    var k = 0;
    while (k < size) {
        total += Math.abs(a[k] - b[k]);
        k += 1;
    }
    return total;
}

// -- neuron

// Neuron whose output is the distance between its weights and the inputs.
// `weights` is kept by reference and mutated in place by the randomize methods.
// `distance(weights, inputs)` is invoked with `this` set to null.
//
function DistanceNeuron(weights, distance) {
    this.getWeights = function () {
        return weights;
    };

    this.getDistance = function () {
        return distance;
    };

    // Randomizes the weights using the default range 0.0 .. 1.0.
    this.randomize$1 = function () {
        this.randomize$2(0.0, 1.0);
    };

    // Sets every weight to min + (max - min) * Math.random().
    this.randomize$2 = function (min, max) {
        var span = max - min;
        var w = 0;
        while (w < weights.length) {
            weights[w] = min + span * Math.random();
            w += 1;
        }
    };

    // Output of the neuron: distance between its weights and `inputs`.
    this.compute = function (inputs) {
        return distance.apply(null, [weights, inputs]);
    };
}


// Distance network:

// -- utils

// Builds a DistanceNetwork from a `height` x `width` grid of neurons.
// createNeuron(rowIndex, columnIndex) is called in row-major order for every cell.
//
function createDistanceNetwork(inputsCount, width, height, createNeuron) {
    // Builds one grid row, asking the factory for each of its cells.
    function buildRow(rowIndex) {
        var cells = Array(width);
        var columnIndex = 0;
        while (columnIndex < cells.length) {
            cells[columnIndex] = createNeuron(rowIndex, columnIndex);
            columnIndex += 1;
        }
        return cells;
    }

    var grid = Array(height);
    for (var rowIndex = 0; rowIndex < grid.length; rowIndex += 1) {
        grid[rowIndex] = buildRow(rowIndex);
    }
    return new DistanceNetwork(inputsCount, grid);
}

// Builds a `height` x `width` DistanceNetwork whose neurons all start with zero weights.
//
function createZeroedNetwork(inputsCount, width, height, distance) {
    var makeZeroedNeuron = function () {
        return createZeroedNeuron(inputsCount, distance);
    };
    return createDistanceNetwork(inputsCount, width, height, makeZeroedNeuron);
}

// Builds a `height` x `width` DistanceNetwork whose neurons start with random weights.
//
function createRandomNetwork(inputsCount, width, height, distance) {
    var makeRandomNeuron = function () {
        return createRandomNeuron(inputsCount, distance);
    };
    return createDistanceNetwork(inputsCount, width, height, makeRandomNeuron);
}

// -- network

// Network made of a 2D grid of distance neurons: neurons[y][x].
// Each neuron's output is its distance to the inputs; the "winner" is the neuron
// with the smallest output.
//
function DistanceNetwork(inputsCount, neurons) {
    // Height is the number of rows; width is taken from the first row only.
    var width = detectArrayWidth(neurons);
    var height = detectArrayHeight(neurons);
    var outputsCount = width * height;

    this.getInputsCount = function() {
        return inputsCount;
    };

    this.getOutputsCount = function() {
        return outputsCount;
    };

    this.getWidth = function() {
        return width;
    };

    this.getHeight = function() {
        return height;
    };

    this.getNeurons = function() {
        return neurons;
    };

    // Returns the neuron at position {x, y}, or null when the position is outside the grid.
    //
    this.getNeuron$1 = function(position) {
        return this.getNeuron$2(position.x, position.y);
    };

    // Returns the neuron at column positionX and row positionY, or null when the
    // position is outside the grid (missing row or missing column).
    //
    this.getNeuron$2 = function(positionX, positionY) {
        var row = neurons[positionY];
        if (row) {
            var neuron = row[positionX];
            return neuron == null ? null : neuron;
        }
        return null;
    };

    // Calls randomize$1() on every neuron (for DistanceNeuron: weights drawn from [0, 1)).
    //
    this.randomize$1 = function() {
        for (var y = 0; y < height; ++y) {
            var row = neurons[y];
            for (var x = 0; x < width; ++x) {
                row[x].randomize$1();
            }
        }
    };

    // Calls randomize$2(min, max) on every neuron.
    //
    this.randomize$2 = function(min, max) {
        for (var y = 0; y < height; ++y) {
            var row = neurons[y];
            for (var x = 0; x < width; ++x) {
                row[x].randomize$2(min, max);
            }
        }
    };

    // Computes neuron outputs, flattened in row-major order (index = y * width + x).
    //
    this.compute$1 = function(inputs) {
        var index = -1;
        var outputs = new Array(outputsCount);
        for (var y = 0; y < height; ++y) {
            var row = neurons[y];
            for (var x = 0; x < width; ++x) {
                outputs[++index] = row[x].compute(inputs);
            }
        }
        return outputs;
    };

    // Computes neuron outputs for each data row.
    //
    this.compute$2 = function(data) {
        var result = Array(data.length);
        for (var i = 0; i < data.length; ++i) {
            result[i] = this.compute$1(data[i]);
        }
        return result;
    };

    // Searches for the winning neuron and returns its position and output as {x, y, value}.
    // The winning neuron is the one whose weights are closest to the inputs (smallest output).
    // On ties the first neuron in row-major order wins; NaN outputs never win. Returns null
    // when the grid is empty or no output is smaller than +Infinity.
    //
    this.search$1 = function(inputs) {
        var positionX = -1;
        var positionY = -1;
        var value = +Infinity;
        for (var y = 0; y < height; ++y) {
            var row = neurons[y];
            for (var x = 0; x < width; ++x) {
                var output = row[x].compute(inputs);
                if (output < value) {
                    positionX = x;
                    positionY = y;
                    value = output;
                }
            }
        }
        if (positionX === -1) {
            return null;
        }
        return {
            x: positionX,
            y: positionY,
            value: value
        };
    };

    // Searches for the winning neuron of each data row and returns their positions.
    //
    this.search$2 = function(data) {
        var result = Array(data.length);
        for (var i = 0; i < data.length; ++i) {
            result[i] = this.search$1(data[i]);
        }
        return result;
    };
}


// WTA learning (winner-takes-all)

// -- utils

// Runs learning.runEpoch(data) repeatedly until `epochsLimit` epochs have been run
// or the epoch error is no longer greater than `toleratedError`.
// Returns the last epoch error, or +Infinity when no epoch was run.
// Note: `rateModel` and `radiusModel` are accepted but not used by this implementation.
// Because `error > undefined` is false, omitting `toleratedError` means no epoch runs.
//
function runWTATraining(learning, data, epochsLimit, toleratedError, rateModel, radiusModel) {
    var epoch = 0;
    var error = +Infinity;
    while (epoch < epochsLimit && error > toleratedError) {
        error = learning.runEpoch(data);
        epoch += 1;
    }
    return error;
}

// -- learning

// Winner-takes-all learning: for every input row, only the winning neuron
// (as reported by network.search$1) is pulled towards the inputs.
//
function WTALearning(network) {
    // Learning rate: the fraction of the input/weight difference applied on each update.
    var rate = 0.01;

    // Moves every weight of `neuron` towards the matching input value and returns the
    // sum of absolute input/weight differences measured before the move.
    function pullTowards(neuron, inputs) {
        var weights = neuron.getWeights();
        var total = 0.0;
        var i = 0;
        while (i < weights.length) {
            var difference = inputs[i] - weights[i];
            weights[i] += rate * difference;
            total += Math.abs(difference);
            i += 1;
        }
        return total;
    }

    this.getRate = function () {
        return rate;
    };

    this.setRate = function (value) {
        rate = value;
    };

    // Trains on a single input row and returns its error,
    // or +Infinity when the network reports no winner.
    this.run = function (inputs) {
        var winner = network.search$1(inputs);
        if (winner == null) {
            return +Infinity;
        }
        var grid = network.getNeurons();
        return pullTowards(grid[winner.y][winner.x], inputs);
    };

    // Trains on every row of `data` in order and returns the summed row errors.
    this.runEpoch = function (data) {
        var epochError = 0.0;
        for (var row = 0; row < data.length; row += 1) {
            epochError += this.run(data[row]);
        }
        return epochError;
    };
}


// Usage example:

// -- configuration

var inputsCount = 2;
var networkWidth = 2;
var networkHeight = 2;

var epochsLimit = 200;      // learning epochs limit
var toleratedError = 0.01;  // tolerated learning error (per-epoch sum returned by WTALearning.runEpoch)

// -- data

var data = [ // as we can see, the input rows can easily be split into 2 areas by their locations
    [1, 2],  // around area 1
    [1, 3],  // around area 1
    [4, 5],  // around area 2
    [5, 5]   // around area 2
];

// -- training

// Available distance functions:
//
//   - calculateChebyshevDistance
//   - calculateRectilinearDistance
//   - calculateEuclideanDistance
//
var network = createRandomNetwork(inputsCount, networkWidth, networkHeight, calculateEuclideanDistance);

var learning = new WTALearning(network);

var error = runWTATraining(learning, data, epochsLimit, toleratedError);

// -- results

console.log('learning error: ' + error);

// Winning neuron for each data row (neuron coordinates: x and y, neuron output: value).
// Rows 0-1 lie in area 1 and rows 2-3 in area 2. Because the initial weights are random
// (in [0, 1)), the two areas are not guaranteed to map to different neurons.
console.log(JSON.stringify(network.search$1(data[0])));  // area 1
console.log(JSON.stringify(network.search$1(data[1])));  // area 1
console.log(JSON.stringify(network.search$1(data[2])));  // area 2
console.log(JSON.stringify(network.search$1(data[3])));  // area 2

Hint: it is good practice to normalize the input data before training.

 

See also

  1. JavaScript - distance neuron

  2. JavaScript - distance network

  3. JavaScript - Kohonen Neural Network (WTM Learning)

  4. JavaScript - Chebyshev Distance function

  5. JavaScript - Rectilinear Distance function

Alternative titles

  1. JavaScript - SOM (WTA Learning)
  2. JavaScript - Self Organizing Map (WTA Learning)
  3. JavaScript - Kohonen Neural Network (winner-takes-all learning)
  4. JavaScript - SOM (winner-takes-all learning)
  5. JavaScript - Self Organizing Map (winner-takes-all learning)
Donate to Dirask
Our content is created by volunteers, like Wikipedia. If you think the things we do are good, please donate. Thanks!
Join our subscribers to stay up to date with content, news and offers.
Native Advertising
🚀
Get your tech brand or product in front of software developers.
For more information, contact us.
Dirask - we help you to
solve coding problems.
Ask question.

❤️💻 🙂

Join