I know there are some limitations with strings in JAX, but I imagine it should be possible to choose a string element from a list based on the values of a JAX NumPy array. For example, with the code below I would like to print `["you", "Hello"]`:
```python
import jax
import jax.numpy as jnp

A = ["Hello", "there", "you"]
B = jnp.array([2, 0])

def get_value(index):
    return A[index]

C = jax.vmap(get_value)(B)
print(C)
```
But this gives the error: `jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].`
I don't mind replacing the list with any data structure that would make this work.
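One possible workaround, sketched under the assumption that only the index computation needs to run inside JAX: JAX arrays have no string dtype, so inside `vmap` the index is a tracer and cannot be used to index a Python list. The sketch below keeps the traced part numeric, then converts the resulting indices to a concrete NumPy array and does the string lookup outside of JAX using a NumPy string array. `get_index` is a hypothetical placeholder for whatever JAX computation produces the indices.

```python
import jax
import jax.numpy as jnp
import numpy as np

A = ["Hello", "there", "you"]
B = jnp.array([2, 0])

# Hypothetical placeholder: any traced, purely numeric work on the index goes here.
def get_index(index):
    return index

# Run the numeric part under vmap; the result is a concrete JAX array of ints.
idx = jax.vmap(get_index)(B)

# Strings can't live in JAX arrays, so do the lookup in NumPy after tracing.
A_np = np.array(A)
C = A_np[np.asarray(idx)].tolist()
print(C)  # ['you', 'Hello']
```

The key design choice is the boundary: everything under `jax.vmap`/`jax.jit` stays numeric, and the mapping from integers to strings happens only on concrete values. A plain list comprehension such as `[A[i] for i in idx.tolist()]` would work equally well in place of the NumPy string array.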