2024-12-03 16:50:19 -06:00
|
|
|
/**
 * Small fully-connected feed-forward network with a single hidden layer and
 * sigmoid activations, trained one example at a time by back-propagation.
 *
 * Weight matrices are stored as arrays of rows: `weights_ih[h][i]` connects
 * input `i` to hidden node `h`, and `weights_ho[o][h]` connects hidden node
 * `h` to output node `o`.
 */
class NeuralNetwork {
  /**
   * @param {number} inputNodes - Number of input values per example.
   * @param {number} hiddenNodes - Number of nodes in the hidden layer.
   * @param {number} outputNodes - Number of output values.
   * @param {number} [learningRate=0.1] - Step size applied to every gradient.
   */
  constructor(inputNodes, hiddenNodes, outputNodes, learningRate = 0.1) {
    this.inputNodes = inputNodes;
    this.hiddenNodes = hiddenNodes;
    this.outputNodes = outputNodes;

    // Weights start as random values in [-1, 1).
    this.weights_ih = this.initializeWeights(this.hiddenNodes, this.inputNodes);
    this.weights_ho = this.initializeWeights(this.outputNodes, this.hiddenNodes);

    // Biases, one per hidden / output node, also random in [-1, 1).
    this.bias_h = this.initializeBiases(this.hiddenNodes);
    this.bias_o = this.initializeBiases(this.outputNodes);

    this.learningRate = learningRate;
  }

  /**
   * Builds a `rows` x `cols` matrix of random weights in [-1, 1).
   * @param {number} rows
   * @param {number} cols
   * @returns {number[][]}
   */
  initializeWeights(rows, cols) {
    return Array.from({ length: rows }, () =>
      Array.from({ length: cols }, () => Math.random() * 2 - 1)
    );
  }

  /**
   * Builds a vector of `size` random biases in [-1, 1).
   * @param {number} size
   * @returns {number[]}
   */
  initializeBiases(size) {
    return Array.from({ length: size }, () => Math.random() * 2 - 1);
  }

  /**
   * Sigmoid activation function.
   * @param {number} x
   * @returns {number} Value in (0, 1).
   */
  sigmoid(x) {
    return 1 / (1 + Math.exp(-x));
  }

  /**
   * Sigmoid derivative expressed in terms of the sigmoid's OUTPUT.
   * Pass an already-activated value `y = sigmoid(z)`, not the raw input `z`.
   * @param {number} x - An activation previously produced by `sigmoid`.
   * @returns {number}
   */
  sigmoidDerivative(x) {
    return x * (1 - x);
  }

  /**
   * Activates one layer: sigmoid(weights · values + biases), row by row.
   * @param {number[][]} weights
   * @param {number[]} biases
   * @param {number[]} values
   * @returns {number[]}
   */
  activateLayer(weights, biases, values) {
    return weights.map((row, i) =>
      this.sigmoid(row.reduce((sum, weight, j) => sum + weight * values[j], 0) + biases[i])
    );
  }

  /**
   * Runs the forward pass and keeps the hidden activations, which training
   * needs in addition to the final outputs.
   * @param {number[]} inputArray
   * @returns {{hidden: number[], outputs: number[]}}
   */
  forwardPass(inputArray) {
    const hidden = this.activateLayer(this.weights_ih, this.bias_h, inputArray);
    const outputs = this.activateLayer(this.weights_ho, this.bias_o, hidden);
    return { hidden, outputs };
  }

  /**
   * Computes the network's outputs for one input vector.
   * @param {number[]} inputArray - `inputNodes` values.
   * @returns {number[]} `outputNodes` activations, each in (0, 1).
   */
  feedforward(inputArray) {
    return this.forwardPass(inputArray).outputs;
  }

  /**
   * Performs one back-propagation step on a single example, updating the
   * weights and biases in place.
   * @param {number[]} inputArray - `inputNodes` values.
   * @param {number[]} targetArray - Desired output, `outputNodes` values.
   */
  train(inputArray, targetArray) {
    const { hidden, outputs } = this.forwardPass(inputArray);

    const outputErrors = targetArray.map((target, i) => target - outputs[i]);

    const outputGradients = outputs.map(
      (output, i) => this.sigmoidDerivative(output) * outputErrors[i] * this.learningRate
    );

    // Each hidden node's error is the sum of every output node's own error,
    // weighted by the connection between the two (computed with the weights
    // as they were before this step's update).
    const hiddenErrors = hidden.map((_, h) =>
      this.weights_ho.reduce((sum, row, o) => sum + row[h] * outputErrors[o], 0)
    );

    const hiddenGradients = hidden.map(
      (activation, h) => this.sigmoidDerivative(activation) * hiddenErrors[h] * this.learningRate
    );

    // Apply the updates.
    this.weights_ho = this.weights_ho.map((row, o) =>
      row.map((weight, h) => weight + outputGradients[o] * hidden[h])
    );
    this.weights_ih = this.weights_ih.map((row, h) =>
      row.map((weight, i) => weight + hiddenGradients[h] * inputArray[i])
    );
    this.bias_o = this.bias_o.map((bias, o) => bias + outputGradients[o]);
    this.bias_h = this.bias_h.map((bias, h) => bias + hiddenGradients[h]);
  }
}
|
|
|
|
|
|
|
|
// Example usage:
|
|
|
|
|
|
|
|
// First example dataset: six phrases and seven 0/1 input rows.
const strings = ["hello", "neural", "world", "how are", "you", "?"];

// One row per training example; every row has one column per entry of `strings`.
const inputs = [
  [0, 0, 0, 0, 0, 0],
  [1, 0, 0, 0, 0, 0],
  [1, 0, 1, 0, 0, 0],
  [1, 1, 1, 0, 0, 0],
  [1, 1, 1, 0, 1, 0],
  [1, 1, 1, 0, 1, 1],
  [1, 1, 1, 1, 1, 1],
];

// One input node per phrase, one hidden node per training row, a single output node.
const nn = new NeuralNetwork(strings.length, inputs.length, 1);

// The target for each row is its own position in `inputs`: [0], [1], ..., [6].
const targets = inputs.map((_, rowIndex) => [rowIndex]);
|
|
|
|
|
2024-12-03 17:25:54 -06:00
|
|
|
|
|
|
|
|
|
|
|
// Second example dataset: three phrases and four 0/1 input rows.
const strings2 = ["i'm", "fine", "myself."];

// One row per training example; every row has one column per entry of `strings2`.
const inputs2 = [
  [0, 0, 0],
  [1, 0, 0],
  [1, 0, 1],
  [1, 1, 1],
];

// One input node per phrase, one hidden node per training row, a single output node.
const nn2 = new NeuralNetwork(strings2.length, inputs2.length, 1);

// The target for each row is its own position in `inputs2`: [0], [1], [2], [3].
const targets2 = inputs2.map((_, rowIndex) => [rowIndex]);
|
|
|
|
|
2024-12-03 17:05:16 -06:00
|
|
|
|
2024-12-03 16:50:19 -06:00
|
|
|
// Training the neural network
|
2024-12-03 17:25:54 -06:00
|
|
|
/**
 * Trains `nn` by repeatedly picking a random example and running one
 * training step on it.
 *
 * @param {{train: function(number[], number[]): void}} nn - Network to train.
 * @param {number[][]} inputs - Training inputs, one row per example.
 * @param {number[][]} targets - Expected outputs; `targets[i]` belongs to `inputs[i]`.
 * @param {number} [iterations=50000] - Number of single-example training steps.
 * @throws {Error} If `inputs` and `targets` differ in length, or the dataset is empty.
 */
function train(nn, inputs, targets, iterations = 50000) {
  if (targets.length !== inputs.length) {
    throw new Error("You dummy, targets length should be equal to inputs length.");
  }
  if (inputs.length === 0) {
    throw new Error("Cannot train on an empty dataset.");
  }

  for (let step = 0; step < iterations; step++) {
    // Sample over the whole dataset so every example can be picked, whatever its size.
    const index = Math.floor(Math.random() * inputs.length);
    nn.train(inputs[index], targets[index]);
  }
}
|
|
|
|
|
|
|
|
|
2024-12-03 17:25:54 -06:00
|
|
|
/**
 * Feeds every row of `feeds` through `nn` and logs the result.
 *
 * For each row it logs `[roundedOutput, row]`. When the rounded output is 1,
 * it also logs the string at index `(sum of the row) - 1`, provided such a
 * string exists. Only the network's first output value is considered.
 *
 * @param {{feedforward: function(number[]): (number[]|number)}} nn - Network to query.
 * @param {number[][]} feeds - Input rows; each must have `strings.length` columns.
 * @param {string[]} strings - Phrases selected by the row sum.
 * @throws {Error} If `strings.length` differs from the length of the first feed row.
 */
function think(nn, feeds, strings) {
  // Nothing to evaluate; also avoids reading `feeds[0]` of an empty list.
  if (feeds.length === 0) {
    return;
  }
  if (strings.length !== feeds[0].length) {
    throw new Error("You dummy, strings array length should be equal to a feed input length.");
  }

  for (const item of feeds) {
    const result = nn.feedforward(item);
    // Read the first output explicitly instead of relying on array-to-number coercion.
    const firstOutput = Array.isArray(result) ? result[0] : result;
    const feededVal = Math.round(firstOutput);
    console.log([feededVal, item]);

    if (feededVal !== 1) {
      continue;
    }

    const sum = item.reduce((total, num) => total + num, 0);
    const phrase = strings[sum - 1];
    // An all-zero row (or an out-of-range sum) has no matching phrase to print.
    if (phrase !== undefined) {
      console.log(phrase);
    }
  }
}
|
|
|
|
|
2024-12-03 17:25:54 -06:00
|
|
|
// Run both examples in turn: train the network on its dataset, then log what
// it produces for those same inputs.
for (const [network, feeds, goals, phrases] of [
  [nn, inputs, targets, strings],
  [nn2, inputs2, targets2, strings2],
]) {
  train(network, feeds, goals);
  think(network, feeds, phrases);
}
|
|
|
|
|