I'm trying to integrate the TFDWT library's DWT3D layer into a custom Keras layer.

import tensorflow as tf
import keras
from TFDWT.DWT3DFB import DWT3D

class DWTLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dwt3d = DWT3D(wave='haar')
    
    def call(self, x):
        #x = tf.transpose(x, [0, 2, 3, 1, 4])
        return self.dwt3d(x)

class SimpleModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dwt_layer = DWTLayer()
        self.conv = keras.layers.Conv3D(1, 3, padding='same')
    
    def call(self, x):
        x = self.dwt_layer(x)
        return self.conv(x)

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = SimpleModel()
    model.compile(optimizer='adam', loss='mse')

with strategy.scope():
    dummy_input = tf.random.normal((4, 128, 128, 128, 4))
    try:
        _ = model(dummy_input, training=False)
        print("Success")
    except TypeError as e:
        print(f" Error: {e}")

Error:

 Error: Exception encountered when calling DWT3D.call().

<tf.Tensor 'dwt3d/concat_2:0' shape=(128, 128) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'dwt3d/concat_2:0' shape=(128, 128) dtype=float32> was defined here:
    File "<frozen runpy>", line 198, in _run_module_as_main
    File "<frozen runpy>", line 88, in _run_code
    File "/usr/local/lib/python3.12/dist-packages/colab_kernel_launcher.py", line 37, in <module>
    File "/usr/local/lib/python3.12/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelapp.py", line 712, in start
    File "/usr/local/lib/python3.12/dist-packages/tornado/platform/asyncio.py", line 211, in start
    File "/usr/lib/python3.12/asyncio/base_events.py", line 645, in run_forever
    File "/usr/lib/python3.12/asyncio/base_events.py", line 1999, in _run_once
    File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
    File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 499, in process_one
    File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
    File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request
    File "/usr/local/lib/python3.12/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute
    File "/usr/local/lib/python3.12/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell
    File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
    File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
    File "/usr/local/lib/python3.12/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
    File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
    File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
    File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    File "/tmp/ipython-input-3621015639.py", line 33, in <cell line: 0>
    File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
    File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 878, in __call__
    File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 1498, in _maybe_build
    File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/core.py", line 240, in compute_output_spec
    File "/tmp/ipython-input-3621015639.py", line 21, in call
    File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
    File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 878, in __call__
    File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 1498, in _maybe_build
    File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/core.py", line 240, in compute_output_spec
    File "/tmp/ipython-input-3621015639.py", line 12, in call
    File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
    File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 878, in __call__
    File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 1489, in _maybe_build
    File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 232, in build_wrapper
    File "/usr/local/lib/python3.12/dist-packages/TFDWT/DWT3DFB.py", line 34, in build
    File "/usr/local/lib/python3.12/dist-packages/TFDWT/DWTFBlayout.py", line 42, in build
    File "/usr/local/lib/python3.12/dist-packages/TFDWT/dwt_op.py", line 38, in make_dwt_operator_matrix_A
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/ops/array_ops.py", line 1441, in concat
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1316, in concat_v2
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/func_graph.py", line 614, in _create_op_internal
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/ops.py", line 2705, in _create_op_internal
    File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/ops.py", line 1200, in from_node_def

The tensor <tf.Tensor 'dwt3d/concat_2:0' shape=(128, 128) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=scratch_graph_1, id=132083996496448), which is out of scope.

Arguments received by DWT3D.call():
  • inputs=tf.Tensor(shape=(4, 128, 128, 128, 4), dtype=float32)

I found a solution:

import tensorflow as tf
import keras
from TFDWT.DWT3DFB import DWT3D

class DWTLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dwt3d = DWT3D(wave='haar')
        self.dwt3d_built = False
    
    def call(self, x):
        #x = tf.transpose(x, [0, 2, 3, 1, 4])
        if not self.dwt3d_built:
            with tf.init_scope():
                self.dwt3d.build(x.shape)
            self.dwt3d_built = True
        
        # Use the initialized instance
        dwt_coef = self.dwt3d(x)  # Works!
        return dwt_coef

class SimpleModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dwt_layer = DWTLayer()
        self.conv = keras.layers.Conv3D(1, 3, padding='same')
    
    def call(self, x):
        x = self.dwt_layer(x)
        return self.conv(x)

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = SimpleModel()
    model.compile(optimizer='adam', loss='mse')

with strategy.scope():
    dummy_input = tf.random.normal((4, 128, 128, 128, 4))
    try:
        _ = model(dummy_input, training=False)
        print("Success")
    except TypeError as e:
        print(f" Error: {e}")

So my questions are:

  1. Why does the DWT layer fail, and is there a way to use it like any other layer, without the workaround above?

  2. Is this workaround safe? Could it affect model performance (accuracy) or training speed?

1 Answer

It fails because of a conflict between how Keras automatically infers shapes and how TFDWT creates its tensors.

The first time the model is called, Keras needs the output shapes of its layers, so it traces each layer's call() in a temporary scratch graph (the traceback shows this as compute_output_spec, and the error names the graph scratch_graph_1). That trace calls self.dwt3d(x), which triggers dwt3d.build(), and build() creates the DWT operator matrix as a tensor inside the scratch graph. Once Keras has the shapes, it discards the scratch graph. When the DWT3D layer then runs for real, it tries to use that matrix, but the tensor belongs to a graph that is no longer in scope.
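A minimal sketch of the same failure mode using only tf.function, without Keras or TFDWT. The names cache, build_in_graph and use_cached are made up for illustration, and tf.function tracing stands in here for the scratch graph; this is not TFDWT's actual code.

```python
import tensorflow as tf

cache = {}  # hypothetical stand-in for an attribute stored on a layer


@tf.function
def build_in_graph(x):
    # Stand-in for a build() that runs while a graph is being traced:
    # the tensor created here is symbolic and belongs to this trace only.
    cache["matrix"] = tf.ones((2, 2)) * 2.0
    return x


build_in_graph(tf.zeros(1))  # tracing runs the Python body once


@tf.function
def use_cached(x):
    # Tries to use a tensor that was defined in a different, finished graph.
    return tf.matmul(x, cache["matrix"])


try:
    use_cached(tf.ones((1, 2)))
    leaked_tensor_rejected = False
except (TypeError, ValueError) as err:
    # TensorFlow refuses the leaked tensor with an "out of scope" message.
    leaked_tensor_rejected = True
    print(type(err).__name__, str(err).splitlines()[0])
```

The tensor stored during the first trace cannot be consumed anywhere else, which is the same situation the DWT3D matrix ends up in.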

Your solution will work, but here is a cleaner way to do it:

class DWTLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dwt3d = DWT3D(wave='haar')

    def build(self, input_shape):
        # Build the wrapped layer outside Keras's temporary tracing graph,
        # so the tensors it creates stay usable after shape inference.
        with tf.init_scope():
            self.dwt3d.build(input_shape)
        super().build(input_shape)

    def call(self, x):
        return self.dwt3d(x)
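
To see what tf.init_scope changes, here is a small sketch that again uses plain tf.function rather than Keras or TFDWT. The names cache, modes, build_lifted and use_cached are hypothetical. Inside init_scope, execution is lifted out of the graph being traced, so the stored tensor is an ordinary eager tensor that later code can use.

```python
import tensorflow as tf

cache = {}  # hypothetical stand-in for an attribute stored on a layer
modes = {}  # records whether each region ran eagerly


@tf.function
def build_lifted(x):
    modes["tracing"] = tf.executing_eagerly()  # False while tracing
    with tf.init_scope():
        # Lifted out of the traced graph: ops here run eagerly.
        modes["init_scope"] = tf.executing_eagerly()
        cache["matrix"] = tf.ones((2, 2)) * 2.0
    return x


build_lifted(tf.zeros(1))  # tracing runs the Python body once


@tf.function
def use_cached(x):
    # The eager tensor is captured as a constant, so this traces fine.
    return tf.matmul(x, cache["matrix"])


result = use_cached(tf.ones((1, 2)))
print(modes)
print(result.numpy())  # [[4. 4.]]
```

In this sketch the matrix is computed once, when build_lifted is traced, and not again on each call to use_cached.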

1 Comment

Thank you so much! Is this approach still safe if I use the output of self.dwt3d(x) for downstream operations such as a conv layer inside call(), i.e. def call(self, x): return self.conv1(self.dwt3d(x)) (assuming self.conv1 is created correctly in __init__)? In other words, will with tf.init_scope(): self.dwt3d.build(input_shape) cause any issues with computational speed or model performance?
