0

I'm trying to access a python list inside a function that can be run in eager mode or graph mode as shown below:

import tensorflow as tf
import numpy as np

class SlotGenerator:
  def __init__(self):
    pass  # No need for initialization here

  def call(self):
      real_part = tf.random.uniform((1, 1, 30720), dtype=tf.float32)
      imag_part = tf.random.uniform((1, 1, 30720), dtype=tf.float32)
      return tf.complex(real_part, imag_part)

class WaveformGenerator:
  def __init__(self, slot_numbers):
    self.slotgens = [SlotGenerator() for _ in range(len(slot_numbers))]
    
  def gen_single_slot(self, slot_num):
    # Generate random slot_matrix with desired dimensions
    return self.slotgens[slot_num].call()
  
  def gen_slots(self, slot_numbers, batch_size=1, num_ant=1, num_time_samples=30720):
      """Generate slots specified by slot_nos"""
        
      slot_shape = tf.TensorShape([batch_size, num_ant, num_time_samples])

      slots = tf.map_fn(
            lambda slot_num: self.gen_single_slot(slot_num),
            slot_numbers,
            fn_output_signature=tf.TensorSpec(shape=slot_shape, dtype=tf.complex64)
        )
        
      slots = tf.squeeze(slots, axis=[1]) # tf.map_fn adds an extra dimension at the beginning of the tensor
      return tf.concat(slots, axis=0)

  def __call__(self, slot_numbers, eager_mode=False):

    if not eager_mode:
        self.gen_slots = tf.function(self.gen_slots)
    
    slots = self.gen_slots(slot_numbers)

    return slots

# Example usage

slot_numbers = [0,1,2,3,4]  # Example number of slots
slot_numbers = np.array(slot_numbers, dtype=np.int32)
waveform_generator = WaveformGenerator(slot_numbers)

eager_mode = False   
waveform = waveform_generator(slot_numbers, eager_mode=eager_mode)

print(waveform.shape)  # Expected output shape: ([5, 1, 30720])

This runs fine in eager mode(set eager_mode = True), but fails in graph mode(set eager_mode = False) with the following error:

TypeError: list indices must be integers or slices, not SymbolicTensor

Is there a workaround for this list indexing problem in graph mode? Or does graph mode support types other than python lists where i can store class objects?

0

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.