I'm trying to integrate the TFDWT library's DWT3D layer into a custom Keras layer.
import tensorflow as tf
import keras
from TFDWT.DWT3DFB import DWT3D
class DWTLayer(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dwt3d = DWT3D(wave='haar')
def call(self, x):
#x = tf.transpose(x, [0, 2, 3, 1, 4])
return self.dwt3d(x)
class SimpleModel(keras.Model):
def __init__(self):
super().__init__()
self.dwt_layer = DWTLayer()
self.conv = keras.layers.Conv3D(1, 3, padding='same')
def call(self, x):
x = self.dwt_layer(x)
return self.conv(x)
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = SimpleModel()
model.compile(optimizer='adam', loss='mse')
with strategy.scope():
dummy_input = tf.random.normal((4, 128, 128, 128, 4))
try:
_ = model(dummy_input, training=False)
print("Success")
except TypeError as e:
print(f" Error: {e}")
Error:
Error: Exception encountered when calling DWT3D.call().
<tf.Tensor 'dwt3d/concat_2:0' shape=(128, 128) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
<tf.Tensor 'dwt3d/concat_2:0' shape=(128, 128) dtype=float32> was defined here:
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/usr/local/lib/python3.12/dist-packages/colab_kernel_launcher.py", line 37, in <module>
File "/usr/local/lib/python3.12/dist-packages/traitlets/config/application.py", line 992, in launch_instance
File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelapp.py", line 712, in start
File "/usr/local/lib/python3.12/dist-packages/tornado/platform/asyncio.py", line 211, in start
File "/usr/lib/python3.12/asyncio/base_events.py", line 645, in run_forever
File "/usr/lib/python3.12/asyncio/base_events.py", line 1999, in _run_once
File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 499, in process_one
File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request
File "/usr/local/lib/python3.12/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute
File "/usr/local/lib/python3.12/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
File "/usr/local/lib/python3.12/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "/tmp/ipython-input-3621015639.py", line 33, in <cell line: 0>
File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 878, in __call__
File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 1498, in _maybe_build
File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/core.py", line 240, in compute_output_spec
File "/tmp/ipython-input-3621015639.py", line 21, in call
File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 878, in __call__
File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 1498, in _maybe_build
File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/core.py", line 240, in compute_output_spec
File "/tmp/ipython-input-3621015639.py", line 12, in call
File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 878, in __call__
File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 1489, in _maybe_build
File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 232, in build_wrapper
File "/usr/local/lib/python3.12/dist-packages/TFDWT/DWT3DFB.py", line 34, in build
File "/usr/local/lib/python3.12/dist-packages/TFDWT/DWTFBlayout.py", line 42, in build
File "/usr/local/lib/python3.12/dist-packages/TFDWT/dwt_op.py", line 38, in make_dwt_operator_matrix_A
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/ops/array_ops.py", line 1441, in concat
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1316, in concat_v2
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/func_graph.py", line 614, in _create_op_internal
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/ops.py", line 2705, in _create_op_internal
File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/ops.py", line 1200, in from_node_def
The tensor <tf.Tensor 'dwt3d/concat_2:0' shape=(128, 128) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=scratch_graph_1, id=132083996496448), which is out of scope.
Arguments received by DWT3D.call():
• inputs=tf.Tensor(shape=(4, 128, 128, 128, 4), dtype=float32)
I found a solution:
import tensorflow as tf
import keras
from TFDWT.DWT3DFB import DWT3D
class DWTLayer(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dwt3d = DWT3D(wave='haar')
self.dwt3d_built = False
def call(self, x):
#x = tf.transpose(x, [0, 2, 3, 1, 4])
if not self.dwt3d_built:
with tf.init_scope():
self.dwt3d.build(x.shape)
self.dwt3d_built = True
# Use the initialized instance
dwt_coef = self.dwt3d(x) # Works!
return dwt_coef
class SimpleModel(keras.Model):
def __init__(self):
super().__init__()
self.dwt_layer = DWTLayer()
self.conv = keras.layers.Conv3D(1, 3, padding='same')
def call(self, x):
x = self.dwt_layer(x)
return self.conv(x)
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = SimpleModel()
model.compile(optimizer='adam', loss='mse')
with strategy.scope():
dummy_input = tf.random.normal((4, 128, 128, 128, 4))
try:
_ = model(dummy_input, training=False)
print("Success")
except TypeError as e:
print(f" Error: {e}")
So my questions are:
Why the DWT layer fails? and Is there is a way to use it normally like other layers without the workaround solution I provided?
Is this a safe solution to use? I mean can this solution affects model performance (accuracy) and speed (training speed)?