SimpleMind: A Neural Network Implementation in JAX


The SimpleMind class is a straightforward neural network implementation written in Python with JAX. It supports several activation functions, optimizers, and L2 regularization, making it versatile for different machine learning tasks, and its parallel backpropagation and logging keep training efficient and easy to follow. Below is a detailed explanation of the class and its components.

Overview

SimpleMind covers parameter initialization, forward propagation, backpropagation, and training, all built on JAX. The activation function and optimizer are selectable at construction time, and optional L2 regularization helps the network generalize.

Components of SimpleMind

Initialization

The constructor initializes the network parameters and sets up logging; a sketch of a possible constructor body follows the parameter list below.

    class SimpleMind:
        def __init__(self, input_size, hidden_sizes, output_size, activation='relu',
                     optimizer='adam', learning_rate=0.001, regularization=None, reg_lambda=0.01):
            # Initialization code
    
    input_size: Size of the input layer.
    hidden_sizes: List of sizes for hidden layers.
    output_size: Size of the output layer.
    activation: Activation function ('relu', 'sigmoid', 'tanh').
    optimizer: Optimizer to use ('adam', 'sgd').
    learning_rate: Learning rate for the optimizer.
    regularization: Regularization method ('l2').
    reg_lambda: Regularization strength.
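
The constructor body is elided above. Below is a minimal sketch of what it might contain, together with the imports the rest of the class needs; the attribute assignments mirror the methods documented later, while the PRNG seed and logging setup are assumptions, not the author's exact code.

    import logging
    from concurrent.futures import ThreadPoolExecutor, as_completed

    import jax.numpy as jnp
    import optax
    from jax import grad, random


    class SimpleMind:
        def __init__(self, input_size, hidden_sizes, output_size, activation='relu',
                     optimizer='adam', learning_rate=0.001, regularization=None, reg_lambda=0.01):
            # store the hyperparameters under the names the other methods expect
            self.input_size = input_size
            self.hidden_sizes = hidden_sizes
            self.output_size = output_size
            self.activation = activation
            self.optimizer_name = optimizer
            self.learning_rate = learning_rate
            self.regularization = regularization
            self.reg_lambda = reg_lambda

            logging.basicConfig(level=logging.INFO)        # logging setup (assumed)
            self.rng = random.PRNGKey(0)                   # PRNG key for weight initialization (seed assumed)
            self.params = self._initialize_parameters()    # weights and biases
            self.opt_state = self._initialize_optimizer()  # also stores self.optimizer (see below)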

Parameter Initialization

The _initialize_parameters method initializes the weights and biases for each layer.

    def _initialize_parameters(self):
        layer_sizes = [self.input_size] + self.hidden_sizes + [self.output_size]
        params = {}
        for i in range(len(layer_sizes) - 1):
            self.rng, layer_rng = random.split(self.rng)
            params[f'W{i}'] = random.normal(layer_rng, (layer_sizes[i], layer_sizes[i+1])) * 0.01
            params[f'b{i}'] = jnp.zeros(layer_sizes[i+1])
        return params
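
As a quick illustration (assuming the constructor sketch above), a network with input_size=4, hidden_sizes=[10, 10], and output_size=1 gets three weight matrices and three bias vectors:

    sm = SimpleMind(input_size=4, hidden_sizes=[10, 10], output_size=1)
    for name, value in sm.params.items():
        print(name, value.shape)
    # W0 (4, 10), b0 (10,), W1 (10, 10), b1 (10,), W2 (10, 1), b2 (1,)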

Optimizer Initialization

The _initialize_optimizer method builds the chosen optax optimizer and returns its initial state.

    def _initialize_optimizer(self):
        if self.optimizer_name == 'adam':
            optimizer = optax.adam(self.learning_rate)
        elif self.optimizer_name == 'sgd':
            optimizer = optax.sgd(self.learning_rate)
        else:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_name}")
        self.optimizer = optimizer  # kept so backpropagate can call self.optimizer.update()
        return optimizer.init(self.params)
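
The rest of the class relies on optax's standard init/update/apply cycle. A small standalone illustration of that cycle, independent of SimpleMind and using a toy quadratic loss:

    # Standalone illustration of the optax update cycle SimpleMind builds on
    import jax.numpy as jnp
    import optax
    from jax import grad

    params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
    optimizer = optax.sgd(0.1)
    opt_state = optimizer.init(params)                        # 1. initialize state from params

    def loss(p):
        return jnp.sum((p['w'] * 2.0 + p['b'] - 3.0) ** 2)    # toy quadratic loss

    grads = grad(loss)(params)                                # gradients as a pytree
    updates, opt_state = optimizer.update(grads, opt_state)   # 2. turn gradients into updates
    params = optax.apply_updates(params, updates)             # 3. apply the updates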

Activation Functions

The _activation_function method applies the chosen activation function.

    def _activation_function(self, s):
        if self.activation == 'sigmoid':
            return 1 / (1 + jnp.exp(-s))
        elif self.activation == 'tanh':
            return jnp.tanh(s)
        elif self.activation == 'relu':
            return jnp.maximum(0, s)
        else:
            raise ValueError("Unsupported activation function.")
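
As a quick check of the three options on the same inputs (outputs rounded):

    import jax.numpy as jnp

    s = jnp.array([-2.0, 0.0, 2.0])
    print(jnp.maximum(0, s))       # relu    -> [0.     0.     2.   ]
    print(jnp.tanh(s))             # tanh    -> [-0.964  0.     0.964]
    print(1 / (1 + jnp.exp(-s)))   # sigmoid -> [0.119  0.5    0.881]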

Forward Propagation

The forward method performs a forward pass through the network. It optionally accepts an explicit parameter dictionary so that backpropagate can differentiate the loss with respect to candidate parameters.

    def forward(self, X, params=None):
        # use the stored parameters unless an explicit set is provided
        params = self.params if params is None else params
        out = X
        for i in range(len(self.hidden_sizes)):
            W, b = params[f'W{i}'], params[f'b{i}']
            out = jnp.dot(out, W) + b
            out = self._activation_function(out)
        # the output layer is linear (no activation)
        W, b = params[f'W{len(self.hidden_sizes)}'], params[f'b{len(self.hidden_sizes)}']
        out = jnp.dot(out, W) + b
        return out
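
To make the data flow concrete, here is a shape trace for a hypothetical batch of 8 samples with the same sizes as above (again assuming the constructor sketch):

    sm = SimpleMind(input_size=4, hidden_sizes=[10, 10], output_size=1)
    X = random.normal(random.PRNGKey(1), (8, 4))   # batch of 8 samples
    print(sm.forward(X).shape)                     # (8, 1)
    # internally: (8, 4) @ (4, 10)  -> (8, 10) -> relu
    #             (8, 10) @ (10, 10) -> (8, 10) -> relu
    #             (8, 10) @ (10, 1)  -> (8, 1), no activation on the output layer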

Backpropagation

The backpropagate method computes the gradients of the regularized loss with respect to the parameters and applies the optimizer update.

    def backpropagate(self, X, y):
        def loss(params):
            predictions = self.forward(X, params)
            loss_value = jnp.mean((predictions - y) ** 2)
            if self.regularization == 'l2':
                l2_penalty = sum(jnp.sum(jnp.square(params[f'W{i}'])) for i in range(len(self.hidden_sizes) + 1))
                loss_value += self.reg_lambda * l2_penalty / 2
            return loss_value

        grads = grad(loss)(self.params)
        updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
        self.params = optax.apply_updates(self.params, updates)
        return self.params, self.opt_state
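
In equation form, the objective being differentiated is the mean squared error plus an optional L2 penalty over every weight matrix:

    loss(params) = mean((forward(X, params) - y)^2) + (reg_lambda / 2) * sum_i ||W_i||^2

grad(loss) evaluated at self.params returns one gradient array per entry in the parameter dictionary, and optax turns those gradients into the parameter updates applied above.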


Training

The train method trains the network for a specified number of epochs.

    def train(self, X, y, epochs):
        for epoch in range(epochs):
            self.params, self.opt_state = self._parallel_backpropagate(X, y)
            if epoch % 100 == 0:
                loss_value = self._calculate_loss(X, y)
                logging.info(f"Epoch {epoch}, Loss: {loss_value}")

Parallel Backpropagation

The _parallel_backpropagate method performs backpropagation in parallel using multiple threads.

    def _parallel_backpropagate(self, X, y):
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.backpropagate, X[i], y[i]) for i in range(len(X))]
            for future in as_completed(futures):
                params, opt_state = future.result()
        return params, opt_state
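
Each submitted task runs backpropagate on a single example (X[i], y[i]), and the method returns the parameters and optimizer state from the last task to complete. For reference, a purely sequential version of the same per-example loop (a hypothetical helper, shown only to spell out what each thread does) would look like this:

    def _sequential_backpropagate(self, X, y):
        # same per-example loop as above, without the thread pool
        for i in range(len(X)):
            params, opt_state = self.backpropagate(X[i], y[i])
        return params, opt_state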

Loss Calculation

The _calculate_loss method calculates the network’s loss: the mean squared error, plus the L2 penalty when regularization is enabled.

    def _calculate_loss(self, X, y):
        output = self.forward(X)
        loss_value = jnp.mean(jnp.square(y - output))
        if self.regularization == 'l2':
            loss_value += self.reg_lambda / 2 * sum(jnp.sum(jnp.square(self.params[f'W{i}'])) for i in range(len(self.hidden_sizes) + 1))
        return loss_value

Example of using the SimpleMind class:

    if __name__ == "__main__":
        input_size = 4
        hidden_sizes = [10, 10]
        output_size = 1
        learning_rate = 0.001
        epochs = 1000

        X = jnp.array([[1.0, 2.0, 3.0, 4.0]])
        y = jnp.array([[1.0]])

        simple_mind = SimpleMind(input_size, hidden_sizes, output_size, activation='relu', optimizer='adam', learning_rate=learning_rate, regularization='l2', reg_lambda=0.01)
        simple_mind.train(X, y, epochs)
        print("Final Output:", simple_mind.forward(X))

The SimpleMind class is a straightforward implementation of a neural network written in Python using JAX. It supports various activation functions, optimizers, and regularization techniques, making it a versatile framework for different machine learning tasks.

General Laptop Optimization: This setting balances the workload between CPU and GPU, using moderately sized hidden layers and the adam optimizer for robust performance across different hardware configurations.

GPU-Optimized: This setting takes advantage of GPU parallelism with larger hidden layers and a slightly lower learning rate for better convergence. The adam optimizer is well-suited for GPUs.

CPU-Optimized: This setting reduces the workload on the CPU with smaller hidden layers and uses the sgd optimizer, which is lightweight and efficient for CPUs. The tanh activation function is chosen for its efficiency on CPUs. A higher learning rate is used to speed up training on the CPU.

    # General laptop optimization settings
    simple_mind_general = SimpleMind(
        input_size=10,             # Adjust based on your specific input data
        hidden_sizes=[64, 32],     # Moderately sized hidden layers
        output_size=1,             # Adjust based on your specific output requirements
        activation='relu',         # Efficient activation function
        optimizer='adam',          # Robust optimizer for varied hardware
        learning_rate=0.001,       # Standard learning rate
        regularization='l2',       # Use L2 regularization to prevent overfitting
        reg_lambda=0.01            # Regularization strength
    )

    # GPU-optimized settings
    simple_mind_gpu = SimpleMind(
        input_size=10,              # Adjust based on your specific input data
        hidden_sizes=[128, 64, 32], # Larger hidden layers to leverage GPU parallelism
        output_size=1,              # Adjust based on your specific output requirements
        activation='relu',          # Efficient activation function for GPUs
        optimizer='adam',           # Optimizer that works well with GPUs
        learning_rate=0.0005,       # Slightly lower learning rate for better convergence
        regularization='l2',        # Use L2 regularization to prevent overfitting
        reg_lambda=0.01             # Regularization strength
    )

    # CPU-optimized settings
    simple_mind_cpu = SimpleMind(
        input_size=10,             # Adjust based on your specific input data
        hidden_sizes=[32, 16],     # Smaller hidden layers to reduce CPU load
        output_size=1,             # Adjust based on your specific output requirements
        activation='tanh',         # Efficient and fast activation function for CPUs
        optimizer='sgd',           # Lightweight optimizer for CPU
        learning_rate=0.01,        # Higher learning rate for faster convergence on CPU
        regularization='l2',       # Use L2 regularization to prevent overfitting
        reg_lambda=0.01            # Regularization strength
    )
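
Any of these instances trains with the same train call shown earlier. For example, with randomly generated placeholder data (the shapes match the configuration above, but the data itself is meaningless):

    # Training the general-purpose configuration on placeholder data
    from jax import random

    key_x, key_y = random.split(random.PRNGKey(42))
    X_demo = random.normal(key_x, (100, 10))   # 100 samples, input_size=10
    y_demo = random.normal(key_y, (100, 1))    # matching targets, output_size=1

    simple_mind_general.train(X_demo, y_demo, epochs=1000)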
                  
                  
easyAGI (c) Gregory L. Magnusson MIT 2024
