
I know that there are some limitations with strings and JAX, but I imagine it should be possible to choose a string element from a list based on the value of a JAX NumPy array. For example, below I would like to print ["you", "Hello"]:

import jax
import jax.numpy as jnp


A = ["Hello", "there", "you"]
B = jnp.array([2, 0])


def get_value(index):
  return A[index]


C = jax.vmap(get_value)(B)


print(C)

But this gives the error: jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].

I don't mind replacing the list with any data structure that would make this work.

Answer


In the course of normal JAX computation, there is no way to index into a Python list with a traced value inside a JAX transformation like vmap. Fundamentally, this is because list indexing is a compile-time operation executed on the host, while transformed operations execute at runtime on the target device. The list and the indices do not exist in the same runtime process, so they cannot interact.
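Since you are open to replacing the list, one option is to move the strings onto the device as numbers. The sketch below is an illustration, not the only approach: it assumes every string can be stored as a zero-padded row of UTF-8 bytes in a uint8 array (the names `table`, `width` and `rows` are hypothetical). Indexing a JAX array with a traced integer works under vmap; turning the bytes back into Python strings still has to happen outside the transformation.

```python
import jax
import jax.numpy as jnp
import numpy as np

A = ["Hello", "there", "you"]
B = jnp.array([2, 0])

# Hypothetical encoding: one row of UTF-8 bytes per string, zero-padded to equal width
width = max(len(s.encode("utf-8")) for s in A)
table = jnp.array(
    [list(s.encode("utf-8").ljust(width, b"\0")) for s in A], dtype=jnp.uint8
)

def get_value(index):
  # Indexing a JAX array with a traced integer is allowed under vmap
  return table[index]

rows = jax.vmap(get_value)(B)  # uint8 array of shape (2, width)

# Decoding back to str happens on the host, after the transformation has finished
C = [bytes(np.asarray(r)).rstrip(b"\0").decode("utf-8") for r in rows]
print(C)  # ['you', 'Hello']
```

Note that stripping trailing zero bytes assumes none of the original strings legitimately end in a NUL character.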

One way to bridge this divide is by using a host callback: this is where you effectively tell the device at runtime to pause its execution and send a message back to the host, then wait for the host to send back its data. The "External callbacks" page of the JAX documentation describes the various possibilities. Note that any data sent back to the device must be a valid JAX value, which does not include strings.
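As a sketch of a callback that does send data back to the device, the example below uses `jax.pure_callback` to look up the strings on the host and return something numeric (here, each string's length, chosen purely for illustration; `host_lengths` and `lengths` are hypothetical names). It is applied under `jax.jit` to the whole index array at once rather than under vmap, to avoid depending on how a given JAX version batches `pure_callback`.

```python
import jax
import jax.numpy as jnp
import numpy as np

A = ["Hello", "there", "you"]
B = jnp.array([2, 0])

def host_lengths(indices):
  # Runs on the host with a NumPy array; must return numbers, not strings
  return np.array([len(A[i]) for i in indices], dtype=np.int32)

@jax.jit
def lengths(indices):
  # The second argument declares the shape and dtype the callback will return
  return jax.pure_callback(
      host_lengths,
      jax.ShapeDtypeStruct(indices.shape, jnp.int32),
      indices,
  )

print(lengths(B))  # [3 5]
```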

It's not clear from your example what you wish to do with the results of your list indexing (strings are not valid return values from a transformed JAX function), but here's an example of using jax.debug.callback to index into a list of strings at runtime and print the entries:

import jax
import jax.numpy as jnp

A = ["Hello", "there", "you"]
B = jnp.array([2, 0])

def get_value(index):
  jax.debug.callback(lambda index: print(A[index]), index)

jax.vmap(get_value)(B)
# Output:
# you
# Hello

Comments

Thanks, that looks useful! But how do I store what vmap returns in a variable instead of just printing it?
You cannot store a string in a variable within a JAX computation: JAX does not support string types.
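A common workaround, sketched here as an assumption about what the asker needs rather than a JAX feature: have the transformed function return only the integer indices, then do the string lookup in ordinary Python after vmap has finished. The function `choose` is a hypothetical stand-in for whatever numeric work you actually do.

```python
import jax
import jax.numpy as jnp
import numpy as np

A = ["Hello", "there", "you"]
B = jnp.array([2, 0])

def choose(index):
  # Hypothetical stand-in: only numbers leave the transformed function
  return index

idx = jax.vmap(choose)(B)  # a JAX int32 array

# The lookup into the Python list happens on the host, outside any transformation
C = [A[i] for i in np.asarray(idx).tolist()]
print(C)  # ['you', 'Hello']
```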
