Skip to main content
update formatting
Source Link

How can this code be improved? I'm a novice programmer trying to learn ml by doing Itit from scratch. This code is part of a transformer model that I'm working on. AnyDo you have any ideas about how to improve it for better performance and easier reality.

import jax

import jax.numpy as jnp

class Embedding():?

import jax

import jax.numpy as jnp

class Embedding():

    def __init__(self, vocab_size, d_model, learning_rate=0.01, decay_steps=100, decay_rate=0.9):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.global_step = 0

    def __call__(self, x):
        return jnp.take(jnp.eye(self.vocab_size), x, axis=0)

    def weights_init(self):
        key = jax.random.PRNGKey(0)
        self.embedding_matrix = jax.random.normal(key, (self.vocab_size, self.d_model))
        key = jax.random.PRNGKey(1)
        self.context_matrix = jax.random.normal(key, (self.d_model, self.vocab_size))

    def forward(self, x):
        # Embedding layer
        # x is a vector of size (vocab_size, 1)
        # embedding_matrix is a matrix of size (vocab_size, d_model)
        # hidden is a vector of size (d_model, 1)
        hidden = jnp.dot(self.embedding_matrix.T, x)
        # Context layer
        # context_matrix is a matrix of size (d_model, vocab_size)
        # output is a vector of size (vocab_size, 1)
        output = jnp.dot(self.context_matrix.T, hidden)
        # Using softmax as activation function
        prediction = jax.nn.softmax(output)
        return hidden, prediction

    def backward(self, hidden, prediction, label):
        # Calculate error
        error = jnp.array(label) - prediction

        # Calculate cross-entropy loss
        loss = -jnp.sum(jnp.array(label) * jnp.log(prediction))

        # Calculate gradient
        grad_context = jnp.dot(hidden, error.T)
        grad_embedding = jnp.dot(error, self.context_matrix.T)

        self.loss = loss
        self.update_weights(grad_context, grad_embedding)



    def update_weights(self, grad_context, grad_embedding):
        # Update weights
        self.context_matrix += self.learning_rate * grad_context
        self.embedding_matrix += self.learning_rate * grad_embedding

        # Update learning rate
        self.global_step += 1
        if self.global_step % self.decay_steps == 0:
            self.learning_rate = self.learning_rate * self.decay_rate

How can this code be improved? I'm a novice programmer trying to learn ml by doing It from scratch. This code is part of a transformer model that I'm working on. Any ideas about how to improve it for better performance and easier reality.

import jax

import jax.numpy as jnp

class Embedding():

def __init__(self, vocab_size, d_model, learning_rate=0.01, decay_steps=100, decay_rate=0.9):
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.learning_rate = learning_rate
    self.decay_steps = decay_steps
    self.decay_rate = decay_rate
    self.global_step = 0

def __call__(self, x):
    return jnp.take(jnp.eye(self.vocab_size), x, axis=0)

def weights_init(self):
    key = jax.random.PRNGKey(0)
    self.embedding_matrix = jax.random.normal(key, (self.vocab_size, self.d_model))
    key = jax.random.PRNGKey(1)
    self.context_matrix = jax.random.normal(key, (self.d_model, self.vocab_size))

def forward(self, x):
    # Embedding layer
    # x is a vector of size (vocab_size, 1)
    # embedding_matrix is a matrix of size (vocab_size, d_model)
    # hidden is a vector of size (d_model, 1)
    hidden = jnp.dot(self.embedding_matrix.T, x)
    # Context layer
    # context_matrix is a matrix of size (d_model, vocab_size)
    # output is a vector of size (vocab_size, 1)
    output = jnp.dot(self.context_matrix.T, hidden)
    # Using softmax as activation function
    prediction = jax.nn.softmax(output)
    return hidden, prediction

def backward(self, hidden, prediction, label):
    # Calculate error
    error = jnp.array(label) - prediction

    # Calculate cross-entropy loss
    loss = -jnp.sum(jnp.array(label) * jnp.log(prediction))

    # Calculate gradient
    grad_context = jnp.dot(hidden, error.T)
    grad_embedding = jnp.dot(error, self.context_matrix.T)

    self.loss = loss
    self.update_weights(grad_context, grad_embedding)



def update_weights(self, grad_context, grad_embedding):
    # Update weights
    self.context_matrix += self.learning_rate * grad_context
    self.embedding_matrix += self.learning_rate * grad_embedding

    # Update learning rate
    self.global_step += 1
    if self.global_step % self.decay_steps == 0:
        self.learning_rate = self.learning_rate * self.decay_rate

How can this code be improved? I'm a novice programmer trying to learn ml by doing it from scratch. This code is part of a transformer model that I'm working on. Do you have any ideas about how to improve it for better performance and easier reality?

import jax

import jax.numpy as jnp

class Embedding():

    def __init__(self, vocab_size, d_model, learning_rate=0.01, decay_steps=100, decay_rate=0.9):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.global_step = 0

    def __call__(self, x):
        return jnp.take(jnp.eye(self.vocab_size), x, axis=0)

    def weights_init(self):
        key = jax.random.PRNGKey(0)
        self.embedding_matrix = jax.random.normal(key, (self.vocab_size, self.d_model))
        key = jax.random.PRNGKey(1)
        self.context_matrix = jax.random.normal(key, (self.d_model, self.vocab_size))

    def forward(self, x):
        # Embedding layer
        # x is a vector of size (vocab_size, 1)
        # embedding_matrix is a matrix of size (vocab_size, d_model)
        # hidden is a vector of size (d_model, 1)
        hidden = jnp.dot(self.embedding_matrix.T, x)
        # Context layer
        # context_matrix is a matrix of size (d_model, vocab_size)
        # output is a vector of size (vocab_size, 1)
        output = jnp.dot(self.context_matrix.T, hidden)
        # Using softmax as activation function
        prediction = jax.nn.softmax(output)
        return hidden, prediction

    def backward(self, hidden, prediction, label):
        # Calculate error
        error = jnp.array(label) - prediction

        # Calculate cross-entropy loss
        loss = -jnp.sum(jnp.array(label) * jnp.log(prediction))

        # Calculate gradient
        grad_context = jnp.dot(hidden, error.T)
        grad_embedding = jnp.dot(error, self.context_matrix.T)

        self.loss = loss
        self.update_weights(grad_context, grad_embedding)



    def update_weights(self, grad_context, grad_embedding):
        # Update weights
        self.context_matrix += self.learning_rate * grad_context
        self.embedding_matrix += self.learning_rate * grad_embedding

        # Update learning rate
        self.global_step += 1
        if self.global_step % self.decay_steps == 0:
            self.learning_rate = self.learning_rate * self.decay_rate
Source Link
T3st
  • 21
  • 1

A simple word embedder only using jax

How can this code be improved? I'm a novice programmer trying to learn ml by doing It from scratch. This code is part of a transformer model that I'm working on. Any ideas about how to improve it for better performance and easier reality.

import jax

import jax.numpy as jnp

class Embedding():

def __init__(self, vocab_size, d_model, learning_rate=0.01, decay_steps=100, decay_rate=0.9):
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.learning_rate = learning_rate
    self.decay_steps = decay_steps
    self.decay_rate = decay_rate
    self.global_step = 0

def __call__(self, x):
    return jnp.take(jnp.eye(self.vocab_size), x, axis=0)

def weights_init(self):
    key = jax.random.PRNGKey(0)
    self.embedding_matrix = jax.random.normal(key, (self.vocab_size, self.d_model))
    key = jax.random.PRNGKey(1)
    self.context_matrix = jax.random.normal(key, (self.d_model, self.vocab_size))

def forward(self, x):
    # Embedding layer
    # x is a vector of size (vocab_size, 1)
    # embedding_matrix is a matrix of size (vocab_size, d_model)
    # hidden is a vector of size (d_model, 1)
    hidden = jnp.dot(self.embedding_matrix.T, x)
    # Context layer
    # context_matrix is a matrix of size (d_model, vocab_size)
    # output is a vector of size (vocab_size, 1)
    output = jnp.dot(self.context_matrix.T, hidden)
    # Using softmax as activation function
    prediction = jax.nn.softmax(output)
    return hidden, prediction

def backward(self, hidden, prediction, label):
    # Calculate error
    error = jnp.array(label) - prediction

    # Calculate cross-entropy loss
    loss = -jnp.sum(jnp.array(label) * jnp.log(prediction))

    # Calculate gradient
    grad_context = jnp.dot(hidden, error.T)
    grad_embedding = jnp.dot(error, self.context_matrix.T)

    self.loss = loss
    self.update_weights(grad_context, grad_embedding)



def update_weights(self, grad_context, grad_embedding):
    # Update weights
    self.context_matrix += self.learning_rate * grad_context
    self.embedding_matrix += self.learning_rate * grad_embedding

    # Update learning rate
    self.global_step += 1
    if self.global_step % self.decay_steps == 0:
        self.learning_rate = self.learning_rate * self.decay_rate