
I load in a dataset as such:

import tensorflow_datasets as tfds

ds = tfds.load(
    'caltech_birds2010',
    split='train',
    as_supervised=True)  # yields (image, label) tuples, which pad(image, label) expects

And this function works fine:

import tensorflow as tf

@tf.function
def pad(image,label):
    return (tf.image.resize_with_pad(image,32,32),label)

ds = ds.map(pad)

But when I try mapping a different library function, Keras's random_rotation:

from tensorflow.keras.preprocessing.image import random_rotation

@tf.function
def rotate(image,label):
    return (random_rotation(image,90), label)

ds = ds.map(rotate)

I get the following error:

AttributeError: 'Tensor' object has no attribute 'ndim'

This is not the only function giving me issues, and it happens with or without the @tf.function decorator.
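A minimal check suggests why the decorator makes no difference. Dataset.map traces the mapped function into a graph either way, so the function receives symbolic tensors rather than NumPy arrays. The sketch uses a small synthetic dataset as a stand-in for caltech_birds2010 (an assumption, so it runs without downloading anything), and the helper name inspect is hypothetical.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for caltech_birds2010 (assumption): 4 RGB images with int64 labels.
images = np.zeros((4, 64, 48, 3), dtype=np.uint8)
labels = np.arange(4, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

trace_log = []

def inspect(image, label):  # hypothetical helper, deliberately without @tf.function
    # This Python body runs while ds.map traces it, not once per element.
    trace_log.append(tf.executing_eagerly())
    return image, label

ds = ds.map(inspect)
print(trace_log)  # records whether eager mode was on inside the traced function
```

NumPy-based helpers such as random_rotation expect concrete arrays, so they break when they are handed these traced tensors.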

Any help is greatly appreciated!

1 Answer


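I would try wrapping random_rotation in tf.py_function, so that it runs eagerly on concrete values instead of on the symbolic tensors that Dataset.map passes in. For example: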

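import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import random_rotation

def _random_rotate(image):
    # Runs eagerly: convert to NumPy and rotate an (H, W, C) image by up to +/-90 degrees.
    rotated = random_rotation(image.numpy(), 90, row_axis=0, col_axis=1, channel_axis=2)
    return rotated.astype(np.float32)

def rotate(image, label):
    im_shape = image.shape
    image = tf.py_function(_random_rotate, [image], tf.float32)  # label passes through untouched
    image.set_shape(im_shape)  # py_function output has no static shape
    return image, label

ds = ds.map(rotate)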
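A self-contained variant of the same idea, runnable without downloading the dataset, is sketched below. It swaps in tf.numpy_function, a close relative of tf.py_function that passes NumPy arrays to the wrapped function directly, and it uses random synthetic images in place of caltech_birds2010. Both are assumptions, not part of the original answer, and the helper name _rotate_np is hypothetical. It also assumes a TensorFlow 2.x release that still ships tf.keras.preprocessing.image.random_rotation, plus SciPy, which that helper relies on for the actual transform.

```python
import numpy as np
import tensorflow as tf

def _rotate_np(image):
    # Plain Python/NumPy: `image` arrives as an (H, W, C) NumPy array.
    # random_rotation defaults to channels-first axes, so pass HWC axes explicitly.
    rotated = tf.keras.preprocessing.image.random_rotation(
        image, 90, row_axis=0, col_axis=1, channel_axis=2)
    return rotated.astype(np.float32)

def rotate(image, label):
    im_shape = image.shape
    image = tf.numpy_function(_rotate_np, [image], tf.float32)
    image.set_shape(im_shape)  # numpy_function output has unknown static shape
    return image, label

# Synthetic stand-in for caltech_birds2010 (assumption): 4 random 32x32 RGB images.
images = np.random.randint(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)
labels = np.arange(4, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(rotate)

for image, label in ds.take(1):
    print(image.shape, image.dtype, int(label))
```

Only the image goes through the Python call; keeping the label out of the wrapped function avoids having to declare its dtype in Tout.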

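Although tf.function and tf.py_function sound similar, they serve different purposes (see the question "What is the difference in purpose between tf.py_function and tf.function?"). tf.function traces Python code into a TensorFlow graph, and Dataset.map already traces the mapped function that way, which is why adding or removing @tf.function makes no difference here. tf.py_function does the opposite: it embeds a call back into ordinary Python, so NumPy-based code like random_rotation receives concrete values. The cost is performance: the wrapped code runs in the Python interpreter under the GIL, whereas native TensorFlow ops in a graph can be optimized and run in parallel.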
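If the Python round-trip becomes a bottleneck, newer TensorFlow documentation marks random_rotation as not recommended for new code and suggests tf.keras.layers.RandomRotation instead, which runs as graph ops inside ds.map. The sketch below assumes TensorFlow 2.6 or later and uses synthetic images in place of caltech_birds2010. The factor of 0.25 is an assumption chosen to approximate random_rotation(image, 90): RandomRotation expresses its range as a fraction of a full turn, so 0.25 allows rotations between -90 and +90 degrees.

```python
import numpy as np
import tensorflow as tf

# factor is a fraction of 2*pi, so 0.25 allows rotations in [-90, 90] degrees.
# fill_mode="nearest" mirrors random_rotation's default fill mode.
rotation_layer = tf.keras.layers.RandomRotation(0.25, fill_mode="nearest")

def rotate(image, label):
    # Native TensorFlow ops: no Python callback, so ds.map can run it in parallel.
    image = rotation_layer(tf.cast(image, tf.float32), training=True)
    return image, label

# Synthetic stand-in for caltech_birds2010 (assumption): 4 random 32x32 RGB images.
images = np.random.randint(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)
labels = np.arange(4, dtype=np.int64)
ds = (tf.data.Dataset.from_tensor_slices((images, labels))
      .map(rotate, num_parallel_calls=tf.data.AUTOTUNE))
```

Passing training=True makes the layer apply its random transform even though it is called outside a model.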
