Building a Neural Network from Scratch with Go: Principles, Structure, and Implementation
This article will introduce how to use the Go programming language to build a simple neural network from scratch and demonstrate its workflow through the Iris classification task. It will combine principle explanations, code implementations, and visual structure displays to help readers understand the core mechanisms of neural networks.
Ⅰ. Basic Principles and Structure of Neural Networks
A neural network is a computational model that simulates biological neurons, achieving non-linear mapping through connections of nodes across multiple layers. A typical three-layer neural network structure includes an input layer, a hidden layer, and an output layer. Nodes in each layer are connected via weights and biases, and inter-layer transmission is processed through activation functions. Below is a schematic diagram of a simple three-layer neural network structure (drawn with ASCII characters):
+-------------+   weights W1, bias b1   +--------------+   weights W2, bias b2   +--------------+
| Input Layer | ----------------------> | Hidden Layer | ----------------------> | Output Layer |
|   4 nodes   |   sigmoid activation    |   3 nodes    |   sigmoid activation    |   3 nodes    |
+-------------+                         +--------------+                         +--------------+
Core Concepts:

- Forward Propagation
  Input data undergoes a linear transformation through the weight matrices (input × weights + bias), then non-linearity is introduced via the activation function, propagating layer by layer to the output layer.
  Formula examples:
  - Hidden layer input: ( Z_1 = X \cdot W_1 + b_1 )
  - Hidden layer output: ( A_1 = \sigma(Z_1) ), where ( \sigma ) is the Sigmoid function
  - Output layer input: ( Z_2 = A_1 \cdot W_2 + b_2 )
  - Output layer output: ( A_2 = \sigma(Z_2) )

- Backpropagation
  By calculating the error between the predicted and true values (here, mean squared error), the weights and biases of each layer are updated in reverse using the chain rule to optimize the model parameters.
  Key steps:
  - Output layer error: ( \delta_2 = (A_2 - Y) \odot \sigma'(Z_2) )
  - Hidden layer error: ( \delta_1 = (\delta_2 \cdot W_2^T) \odot \sigma'(Z_1) ), where ( \odot ) denotes element-wise multiplication
  - Update weights: ( W_2 \leftarrow W_2 - \eta \cdot A_1^T \cdot \delta_2 ), ( W_1 \leftarrow W_1 - \eta \cdot X^T \cdot \delta_1 )
  - Update biases: ( b_2 \leftarrow b_2 - \eta \cdot \sum \delta_2 ), ( b_1 \leftarrow b_1 - \eta \cdot \sum \delta_1 ), where ( \eta ) is the learning rate and ( \sigma' ) is the derivative of the activation function
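As a concrete dimension check for the Iris network configured later (4 input features, 5 hidden nodes, 3 output classes), a batch of ( n ) samples gives ( X ) of size ( n \times 4 ), ( W_1 ) of size ( 4 \times 5 ), ( b_1 ) of size ( 1 \times 5 ), ( W_2 ) of size ( 5 \times 3 ), and ( b_2 ) of size ( 1 \times 3 ); the output ( A_2 ) is therefore ( n \times 3 ), one score per species for every sample (the bias rows are broadcast across all ( n ) rows during forward propagation).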
Ⅱ. Key Design of Neural Network Implementation in Go
1. Data Structure Definition
Use the gonum.org/v1/gonum/mat package for matrix operations, and define the network structure and parameters:
// neuralNet stores trained neural network parameters
type neuralNet struct {
    config  neuralNetConfig // Network configuration
    wHidden *mat.Dense      // Hidden layer weight matrix
    bHidden *mat.Dense      // Hidden layer bias vector
    wOut    *mat.Dense      // Output layer weight matrix
    bOut    *mat.Dense      // Output layer bias vector
}

// neuralNetConfig defines the network architecture and training parameters
type neuralNetConfig struct {
    inputNeurons  int     // Number of input layer nodes (e.g., 4 features of Iris)
    outputNeurons int     // Number of output layer nodes (e.g., 3 Iris species)
    hiddenNeurons int     // Number of hidden layer nodes (tunable hyperparameter)
    numEpochs     int     // Number of training epochs
    learningRate  float64 // Learning rate
}
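The main program later calls a newNetwork constructor that the original listing does not show; a minimal version might simply store the configuration and leave the parameter matrices to be allocated during training (a sketch under that assumption):

// newNetwork returns a neuralNet initialized with the given configuration;
// the weight and bias matrices are allocated later, inside train.
func newNetwork(config neuralNetConfig) *neuralNet {
    return &neuralNet{config: config}
}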
2. Activation Function and Its Derivative
The Sigmoid function is used as the activation function; its derivative can be computed cheaply from the function value itself, which makes it well suited to backpropagation:
// sigmoid implements the Sigmoid activation function
func sigmoid(x float64) float64 {
    return 1.0 / (1.0 + math.Exp(-x))
}

// sigmoidPrime implements the derivative of the Sigmoid function
func sigmoidPrime(x float64) float64 {
    s := sigmoid(x)
    return s * (1.0 - s)
}
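One practical detail: gonum's (*mat.Dense).Apply expects a function of the form func(i, j int, v float64) float64, so these scalar helpers need thin adapters before they can be applied to whole matrices. The training and prediction code below uses the following adapters (an addition implied by, but not shown in, the original listing):

// applySigmoid and applySigmoidPrime adapt the scalar helpers to the
// signature expected by (*mat.Dense).Apply, which also passes the row and
// column indices (ignored here).
func applySigmoid(_, _ int, v float64) float64      { return sigmoid(v) }
func applySigmoidPrime(_, _ int, v float64) float64 { return sigmoidPrime(v) }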
3. Backpropagation Training Logic
Parameter Initialization
Initialize the weights and biases with random numbers so that different neurons start from different values and the network can actually learn:
func (nn *neuralNet) train(x, y *mat.Dense) error {
    // Random number generator for parameter initialization
    randGen := rand.New(rand.NewSource(time.Now().UnixNano()))

    // Allocate weights and biases for the hidden and output layers
    wHidden := mat.NewDense(nn.config.inputNeurons, nn.config.hiddenNeurons, nil)
    bHidden := mat.NewDense(1, nn.config.hiddenNeurons, nil)
    wOut := mat.NewDense(nn.config.hiddenNeurons, nn.config.outputNeurons, nil)
    bOut := mat.NewDense(1, nn.config.outputNeurons, nil)

    // Fill the parameter matrices with random values
    for _, param := range [][]*mat.Dense{{wHidden, bHidden}, {wOut, bOut}} {
        for _, m := range param {
            raw := m.RawMatrix().Data
            for i := range raw {
                raw[i] = randGen.Float64() // Random values in [0, 1)
            }
        }
    }

    // Run backpropagation to train the parameters
    return nn.backpropagate(x, y, wHidden, bHidden, wOut, bOut)
}
Core Backpropagation Logic
Error backpropagation and the parameter updates are implemented with matrix operations. The code uses the Apply method to batch-process the activation function and its derivative:
func (nn *neuralNet) backpropagate(x, y, wHidden, bHidden, wOut, bOut *mat.Dense) error {
    for epoch := 0; epoch < nn.config.numEpochs; epoch++ {
        // Forward propagation: compute the output of each layer
        hiddenInput := new(mat.Dense)
        hiddenInput.Mul(x, wHidden) // Hidden layer linear input: X·W_hidden
        hiddenInput.Apply(func(_, col int, v float64) float64 { // Add bias term
            return v + bHidden.At(0, col)
        }, hiddenInput)
        hiddenAct := new(mat.Dense)
        hiddenAct.Apply(applySigmoid, hiddenInput) // Hidden layer activated output

        outputInput := new(mat.Dense)
        outputInput.Mul(hiddenAct, wOut) // Output layer linear input: A_hidden·W_out
        outputInput.Apply(func(_, col int, v float64) float64 { // Add bias term
            return v + bOut.At(0, col)
        }, outputInput)
        output := new(mat.Dense)
        output.Apply(applySigmoid, outputInput) // Output layer activated output

        // Backpropagation: compute errors and gradients
        networkError := new(mat.Dense)
        networkError.Sub(y, output) // Output error: Y - A_out

        outputSlope := new(mat.Dense)
        outputSlope.Apply(applySigmoidPrime, outputInput) // σ'(Z_out)
        dOutput := new(mat.Dense)
        dOutput.MulElem(networkError, outputSlope) // (Y - A_out) ⊙ σ'(Z_out)

        hiddenError := new(mat.Dense)
        hiddenError.Mul(dOutput, wOut.T()) // Propagate error back: dOutput·W_out^T
        hiddenSlope := new(mat.Dense)
        hiddenSlope.Apply(applySigmoidPrime, hiddenInput) // σ'(Z_hidden)
        dHidden := new(mat.Dense)
        dHidden.MulElem(hiddenError, hiddenSlope) // (dOutput·W_out^T) ⊙ σ'(Z_hidden)

        // Update weights and biases (gradient descent).
        // Note: networkError = Y - A_out is the negative of δ in the formulas
        // above, so the adjustments are added here rather than subtracted.
        wOutAdj := new(mat.Dense)
        wOutAdj.Mul(hiddenAct.T(), dOutput)
        wOutAdj.Scale(nn.config.learningRate, wOutAdj)
        wOut.Add(wOut, wOutAdj)

        bOutAdj := sumAlongAxis(0, dOutput)
        bOutAdj.Scale(nn.config.learningRate, bOutAdj)
        bOut.Add(bOut, bOutAdj)

        wHiddenAdj := new(mat.Dense)
        wHiddenAdj.Mul(x.T(), dHidden)
        wHiddenAdj.Scale(nn.config.learningRate, wHiddenAdj)
        wHidden.Add(wHidden, wHiddenAdj)

        bHiddenAdj := sumAlongAxis(0, dHidden)
        bHiddenAdj.Scale(nn.config.learningRate, bHiddenAdj)
        bHidden.Add(bHidden, bHiddenAdj)
    }

    // Save the trained parameters
    nn.wHidden, nn.bHidden, nn.wOut, nn.bOut = wHidden, bHidden, wOut, bOut
    return nil
}
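backpropagate relies on a sumAlongAxis helper that the listing does not define; it sums the per-sample delta rows so that each bias receives one accumulated gradient per column. A possible implementation using only gonum's mat and floats packages:

// sumAlongAxis sums a matrix along axis 0 (down the rows, giving a 1×cols
// result) or axis 1 (across the columns, giving a rows×1 result).
func sumAlongAxis(axis int, m *mat.Dense) *mat.Dense {
    numRows, numCols := m.Dims()
    switch axis {
    case 0:
        data := make([]float64, numCols)
        for j := 0; j < numCols; j++ {
            data[j] = floats.Sum(mat.Col(nil, j, m))
        }
        return mat.NewDense(1, numCols, data)
    case 1:
        data := make([]float64, numRows)
        for i := 0; i < numRows; i++ {
            data[i] = floats.Sum(mat.Row(nil, i, m))
        }
        return mat.NewDense(numRows, 1, data)
    default:
        panic("sumAlongAxis: axis must be 0 or 1")
    }
}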
4. Forward Prediction Function
After training, use trained weights and biases for forward propagation to output predictions:
func (nn *neuralNet) predict(x *mat.Dense) (*mat.Dense, error) {
    // Refuse to predict before the network has been trained
    if nn.wHidden == nil || nn.wOut == nil {
        return nil, errors.New("neural network not trained")
    }

    // Hidden layer: sigmoid(X·W_hidden + b_hidden)
    hiddenInput := new(mat.Dense)
    hiddenInput.Mul(x, nn.wHidden)
    hiddenInput.Apply(func(_, col int, v float64) float64 { return v + nn.bHidden.At(0, col) }, hiddenInput)
    hiddenAct := new(mat.Dense)
    hiddenAct.Apply(applySigmoid, hiddenInput)

    // Output layer: sigmoid(A_hidden·W_out + b_out)
    outputInput := new(mat.Dense)
    outputInput.Mul(hiddenAct, nn.wOut)
    outputInput.Apply(func(_, col int, v float64) float64 { return v + nn.bOut.At(0, col) }, outputInput)
    output := new(mat.Dense)
    output.Apply(applySigmoid, outputInput)

    return output, nil
}
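Once a network has been trained (as in the main program of the next section), a single flower can be scored by packing its four pre-scaled feature values into a 1×4 matrix. A usage sketch with arbitrary example values:

// Score one sample; the feature values here are arbitrary, already-scaled numbers.
sample := mat.NewDense(1, 4, []float64{0.5, 0.4, 0.6, 0.6})
pred, err := network.predict(sample)
if err != nil {
    log.Fatal(err)
}
fmt.Printf("class scores: %v\n", mat.Formatted(pred)) // one score per species column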
Ⅲ. Data Processing and Experimental Validation
1. Dataset Preparation
Use the classic Iris Dataset, containing 4 features (sepal length, sepal width, petal length, petal width) and 3 species (Setosa, Versicolor, Virginica). Data preprocessing steps:
- Convert the species labels to one-hot encoding, e.g., Setosa corresponds to [1, 0, 0] and Versicolor to [0, 0, 1] (a small encoding sketch follows the sample data below).
- Split the data 80% / 20% into training and test sets, and add small random noise to increase the training difficulty.
- Sample data (excerpt from train.csv):
sepal_length,sepal_width,petal_length,petal_width,setosa,virginica,versicolor
0.0873,0.6687,0.0,0.0417,1.0,0.0,0.0
0.7232,0.4533,0.6949,0.967,0.0,1.0,0.0
0.6617,0.4567,0.6580,0.6567,0.0,0.0,1.0
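The preprocessing itself happens before the Go program runs (train.csv and test.csv are assumed to already exist), but the one-hot step can be expressed in Go as in the sketch below, which assumes the column order of the header above and the species names used in the raw Iris data (the helper speciesToOneHot is hypothetical):

// speciesToOneHot maps an Iris species name to the label columns
// setosa, virginica, versicolor used in train.csv.
func speciesToOneHot(species string) ([]float64, error) {
    switch species {
    case "Iris-setosa":
        return []float64{1, 0, 0}, nil
    case "Iris-virginica":
        return []float64{0, 1, 0}, nil
    case "Iris-versicolor":
        return []float64{0, 0, 1}, nil
    default:
        return nil, fmt.Errorf("unknown species: %q", species)
    }
}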
2. Main Program Flow
Read Data and Convert to Matrices
func main() {
    // Open the training data file
    f, err := os.Open("data/train.csv")
    if err != nil {
        log.Fatalf("failed to open file: %v", err)
    }
    defer f.Close()

    reader := csv.NewReader(f)
    reader.FieldsPerRecord = 7 // 4 features + 3 one-hot labels
    rawData, err := reader.ReadAll()
    if err != nil {
        log.Fatalf("failed to read CSV: %v", err)
    }

    // Parse the records into input features (X) and labels (Y)
    numSamples := len(rawData) - 1 // Skip the header row
    inputsData := make([]float64, 4*numSamples)
    labelsData := make([]float64, 3*numSamples)
    for i, record := range rawData {
        if i == 0 {
            continue // Skip header
        }
        for j, val := range record {
            fVal, err := strconv.ParseFloat(val, 64)
            if err != nil {
                log.Fatalf("invalid value: %v", val)
            }
            if j < 4 {
                inputsData[(i-1)*4+j] = fVal // First 4 columns are features
            } else {
                labelsData[(i-1)*3+(j-4)] = fVal // Last 3 columns are labels
            }
        }
    }
    inputs := mat.NewDense(numSamples, 4, inputsData)
    labels := mat.NewDense(numSamples, 3, labelsData)

    // main continues with the training and evaluation snippets below
Configure Network Parameters and Train
// Define the network structure: 4 inputs, 5 hidden nodes, 3 outputs
config := neuralNetConfig{
    inputNeurons:  4,
    outputNeurons: 3,
    hiddenNeurons: 5,
    numEpochs:     8000, // Train for 8000 epochs
    learningRate:  0.2,  // Learning rate
}
network := newNetwork(config)
if err := network.train(inputs, labels); err != nil {
    log.Fatalf("training failed: %v", err)
}
Test Model Accuracy
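The snippet below assumes that testInputs and testLabels have already been built from data/test.csv in exactly the same way the training matrices were built from train.csv. One way to avoid duplicating the parsing loop is a small helper like the following sketch (the name loadMatrices is hypothetical, not part of the original listing):

// loadMatrices reads a CSV file with 4 feature columns and 3 one-hot label
// columns (plus a header row) and returns the feature and label matrices.
func loadMatrices(path string) (*mat.Dense, *mat.Dense, error) {
    f, err := os.Open(path)
    if err != nil {
        return nil, nil, err
    }
    defer f.Close()

    reader := csv.NewReader(f)
    reader.FieldsPerRecord = 7
    rawData, err := reader.ReadAll()
    if err != nil {
        return nil, nil, err
    }

    numSamples := len(rawData) - 1 // skip the header row
    inputsData := make([]float64, 4*numSamples)
    labelsData := make([]float64, 3*numSamples)
    for i, record := range rawData[1:] {
        for j, val := range record {
            fVal, err := strconv.ParseFloat(val, 64)
            if err != nil {
                return nil, nil, err
            }
            if j < 4 {
                inputsData[i*4+j] = fVal
            } else {
                labelsData[i*3+(j-4)] = fVal
            }
        }
    }
    return mat.NewDense(numSamples, 4, inputsData), mat.NewDense(numSamples, 3, labelsData), nil
}

With such a helper, main could obtain the test matrices via testInputs, testLabels, err := loadMatrices("data/test.csv") before running the accuracy check below.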
// Read the test data and run prediction
predictions, err := network.predict(testInputs)
if err != nil {
    log.Fatalf("prediction failed: %v", err)
}

// Calculate classification accuracy
trueCount := 0
numPreds, _ := predictions.Dims()
for i := 0; i < numPreds; i++ {
    // Recover the true class from the one-hot label (index of the 1.0)
    trueLabel := mat.Row(nil, i, testLabels)
    trueClass := -1
    for j, val := range trueLabel {
        if val == 1.0 {
            trueClass = j
            break
        }
    }
    // The predicted class is the column with the highest output value
    predRow := mat.Row(nil, i, predictions)
    predClass := floats.MaxIdx(predRow)
    if trueClass == predClass {
        trueCount++
    }
}
fmt.Printf("Accuracy: %.2f%%\n", float64(trueCount)/float64(numPreds)*100)
Ⅳ. Experimental Results and Summary
After 8000 training epochs, the model achieves approximately 98% classification accuracy on the test set (results may vary slightly due to random initialization). This demonstrates that even a simple three-layer neural network can effectively solve non-linear classification problems.
Core Advantages:
- Pure Go Implementation: No reliance on C extensions (no cgo); the program compiles to a static binary, which suits cross-platform deployment.
- Matrix Abstraction: Numerical computation is based on the gonum/mat package, keeping the code structure clear and easy to extend.
Improvement Directions:
- Experiment with different activation functions (e.g., ReLU) or optimizers (e.g., Adam).
- Add regularization (e.g., L2 regularization) to prevent overfitting.
- Support multiple hidden layers to build deeper neural networks.
Leapcell: The Best of Serverless Web Hosting
Finally, a recommendation for the best platform for deploying Go services: Leapcell
🚀 Build with Your Favorite Language
Develop effortlessly in JavaScript, Python, Go, or Rust.
🌍 Deploy Unlimited Projects for Free
Only pay for what you use—no requests, no charges.
⚡ Pay-as-You-Go, No Hidden Costs
No idle fees, just seamless scalability.
🔹 Follow us on Twitter: @LeapcellHQ