# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameters for layer classes."""
import abc
from typing import Any, Dict
import tensorflow as tf
from tensorflow_compression.python.ops import math_ops
__all__ = [
"Parameter",
"RDFTParameter",
"GDNParameter",
]
class Parameter(tf.Module, metaclass=abc.ABCMeta):
"""Reparameterized `Layer` variable.
This object represents a parameter of a `tf.keras.layer.Layer` object which
isn't directly stored in a `tf.Variable`, but can be represented as a function
(of any number of `tf.Variable` attributes).
"""
@abc.abstractmethod
def __call__(self, compute_dtype=None):
"""Computes and returns the parameter value as a `tf.Tensor`."""
@abc.abstractmethod
def get_config(self):
"""Returns the configuration of the `Parameter`."""
return dict(name=self.name)
def get_weights(self):
return tf.keras.backend.batch_get_value(self.variables)
def set_weights(self, weights):
if len(weights) != len(self.variables):
raise ValueError(
f"set_weights() expects a list of {len(self.variables)} arrays, "
f"received {len(weights)}.")
tf.keras.backend.batch_set_value(zip(self.variables, weights))
def _parameter_conversion_func(value, dtype=None, name=None, as_ref=False):
del name # not supported
if as_ref:
raise ValueError("as_ref=True is not supported.")
return value(compute_dtype=dtype)
tf.register_tensor_conversion_function(
Parameter, _parameter_conversion_func,
)
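
# Registering the conversion function above lets a `Parameter` stand in for a
# `tf.Tensor`: `tf.convert_to_tensor(param)`, and by extension ops that
# implicitly convert their arguments, will call the parameter to obtain its
# current value. Illustrative sketch (not from the original source), assuming
# `kernel` is an `RDFTParameter` for a 2D convolution and `x` is a float32
# image batch:
#
#   y = tf.nn.conv2d(x, kernel, strides=1, padding="SAME")
#
# Here `kernel` is converted to its current kernel tensor on the fly.
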
@tf.keras.utils.register_keras_serializable(package="tensorflow_compression")
class RDFTParameter(Parameter):
  """RDFT reparameterization of a convolution kernel.

  This uses the real-input discrete Fourier transform (RDFT) of a kernel as
  its parameterization. The inverse RDFT is applied to the variable to produce
  the kernel (see https://en.wikipedia.org/wiki/Discrete_Fourier_transform).

  Attributes:
    shape: `tf.TensorShape`. The shape of the convolution kernel.
    real: `tf.Variable`. The real part of the RDFT of the kernel.
    imag: `tf.Variable`. The imaginary part of the RDFT of the kernel.
  """

  def __init__(self, initial_value, name=None, shape=None, dtype=None):
    """Initializer.

    Args:
      initial_value: `tf.Tensor` or `None`. The initial value of the kernel.
        If not provided, `shape` must be given, and the initial value of the
        parameter will be undefined.
      name: String. The name of the kernel.
      shape: `tf.TensorShape` or compatible. Ignored unless `initial_value is
        None`.
      dtype: `tf.dtypes.DType` or compatible. DType of this parameter. If not
        given, inferred from `initial_value`.
    """
    super().__init__(name=name)
    if initial_value is None:
      if shape is None:
        raise ValueError("If initial_value is None, shape must be specified.")
      initial_value = tf.zeros(shape, dtype=dtype)
    else:
      initial_value = tf.convert_to_tensor(initial_value, dtype=dtype)
    self._shape = initial_value.shape
    self._dtype = initial_value.dtype
    # Move the channel dimensions to the front and take the RDFT over the
    # remaining (spatial) dimensions.
    if self.shape.rank == 3:
      initial_value = tf.transpose(initial_value, (1, 2, 0))
      initial_value = tf.signal.rfft(initial_value)
    elif self.shape.rank == 4:
      initial_value = tf.transpose(initial_value, (2, 3, 0, 1))
      initial_value = tf.signal.rfft2d(initial_value)
    elif self.shape.rank == 5:
      initial_value = tf.transpose(initial_value, (3, 4, 0, 1, 2))
      initial_value = tf.signal.rfft3d(initial_value)
    else:
      raise ValueError(
          f"Expected kernel tensor of rank 3, 4, or 5; received shape "
          f"{self._shape}.")
    norm = tf.constant(
        self.shape[:-2].num_elements() ** .5, initial_value.dtype)
    initial_value /= norm
    # We split the variable into real and imaginary parts to avoid issues with
    # complex-valued variables being unsupported when saving models, etc.
    real = tf.math.real(initial_value)
    imag = tf.math.imag(initial_value)
    real_name = imag_name = None
    if name is not None:
      real_name = f"{name}_real"
      imag_name = f"{name}_imag"
    self.real = tf.Variable(real, name=real_name)
    self.imag = tf.Variable(imag, name=imag_name)

  @property
  def dtype(self) -> tf.dtypes.DType:
    return self._dtype

  @property
  def shape(self) -> tf.TensorShape:
    return self._shape

  @tf.Module.with_name_scope
  def __call__(self, compute_dtype=None) -> tf.Tensor:
    """Computes and returns the convolution kernel as a `tf.Tensor`."""
    real, imag = self.real, self.imag
    if compute_dtype in (tf.bfloat16, tf.float16):
      # As of 2022-02, there is no half precision complex math in TensorFlow,
      # so we need to use at least 32 bits.
      real = tf.cast(real, tf.float32)
      imag = tf.cast(imag, tf.float32)
    elif compute_dtype is not None:
      real = tf.cast(real, compute_dtype)
      imag = tf.cast(imag, compute_dtype)
    rdft = tf.dtypes.complex(real, imag)
    norm = tf.constant(self.shape[:-2].num_elements() ** .5, rdft.dtype)
    rdft *= norm
    if self.shape.rank == 3:
      kernel = tf.signal.irfft(rdft, fft_length=self.shape[:-2])
      kernel = tf.transpose(kernel, (2, 0, 1))
    elif self.shape.rank == 4:
      kernel = tf.signal.irfft2d(rdft, fft_length=self.shape[:-2])
      kernel = tf.transpose(kernel, (2, 3, 0, 1))
    else:
      assert self.shape.rank == 5, self.shape
      kernel = tf.signal.irfft3d(rdft, fft_length=self.shape[:-2])
      kernel = tf.transpose(kernel, (2, 3, 4, 0, 1))
    if compute_dtype is not None:
      # If we had to bump up precision to 32 bits, finally cast to
      # compute_dtype here. In other cases, this should be a no-op.
      kernel = tf.cast(kernel, compute_dtype)
    return kernel

  def get_config(self) -> Dict[str, Any]:
    config = super().get_config()
    config.update(
        initial_value=None,
        shape=tuple(map(int, self.shape)),
        dtype=self.dtype.name,
    )
    return config
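
# Illustrative sketch (not from the original source): `get_config()` stores
# `initial_value=None`, so a deserialized instance starts undefined and gets
# its state through the `get_weights()`/`set_weights()` mechanism, e.g.:
#
#   param = RDFTParameter(tf.random.normal((5, 5, 3, 16)), name="kernel")
#   clone = RDFTParameter(**param.get_config())
#   clone.set_weights(param.get_weights())  # copies the real/imag variables
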
@tf.keras.utils.register_keras_serializable(package="tensorflow_compression")
class GDNParameter(Parameter):
  """Nonnegative parameterization as needed for GDN parameters.

  The variable is subjected to an invertible transformation that slows down
  the learning rate for small values.

  Attributes:
    minimum: Float. The `minimum` parameter provided on initialization.
    offset: Float. The `offset` parameter provided on initialization.
    variable: `tf.Variable`. The reparameterized variable.
  """

  def __init__(self, initial_value, name=None, minimum=0., offset=2 ** -18,
               shape=None, dtype=None):
    """Initializer.

    Args:
      initial_value: `tf.Tensor` or `None`. The initial value of the
        parameter. If not provided, `shape` must be given, and the initial
        value of the parameter will be undefined.
      name: String. The name of the parameter.
      minimum: Float. Lower bound for the parameter (defaults to zero).
      offset: Float. Offset added to the reparameterization. The
        parameterization of beta/gamma as their square roots lets the
        training slow down when values are close to zero, which is desirable
        as small values in the denominator can lead to a situation where
        gradient noise on beta/gamma leads to extreme amounts of noise in the
        GDN activations. However, without the offset, we would get zero
        gradients if any elements of beta or gamma were exactly zero, and
        thus the training could get stuck. To prevent this, we add this small
        constant. The default value was empirically determined to be a good
        starting point. Making it bigger potentially leads to more gradient
        noise on the activations, while making it too small may lead to
        numerical precision issues.
      shape: `tf.TensorShape` or compatible. Ignored unless `initial_value is
        None`.
      dtype: `tf.dtypes.DType` or compatible. DType of this parameter. If not
        given, inferred from `initial_value`.
    """
    super().__init__(name=name)
    self._minimum = float(minimum)
    self._offset = float(offset)
    if initial_value is None:
      if shape is None:
        raise ValueError("If initial_value is None, shape must be specified.")
      initial_value = tf.zeros(shape, dtype=dtype)
    else:
      initial_value = tf.convert_to_tensor(initial_value, dtype=dtype)
    # Map the initial value through the inverse of the reparameterization,
    # i.e. variable = sqrt(max(initial_value + offset**2, offset**2)).
    pedestal = tf.constant(self.offset ** 2, dtype=initial_value.dtype)
    initial_value = tf.math.sqrt(
        tf.math.maximum(initial_value + pedestal, pedestal))
    if name is not None:
      name = f"reparam_{name}"
    self.variable = tf.Variable(initial_value, name=name)

  @tf.Module.with_name_scope
  def __call__(self, compute_dtype=None) -> tf.Tensor:
    """Computes and returns the non-negative value as a `tf.Tensor`."""
    variable = self.variable
    if compute_dtype is not None:
      variable = tf.cast(variable, compute_dtype)
    pedestal = tf.constant(self.offset ** 2, dtype=variable.dtype)
    bound = tf.constant(
        (self.minimum + self.offset ** 2) ** .5, dtype=variable.dtype)
    reparam_value = math_ops.lower_bound(variable, bound)
    return tf.math.square(reparam_value) - pedestal
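
  # Worked example with the defaults (minimum=0, offset=2**-18): the variable
  # is clipped from below at bound = sqrt(0 + 2**-36) = 2**-18, so the value
  # returned above is lower_bound(v, 2**-18)**2 - 2**-36 >= 0. In general,
  # the result is bounded below by `minimum` for any choice of parameters.
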
  @property
  def minimum(self) -> float:
    return self._minimum

  @property
  def offset(self) -> float:
    return self._offset

  def get_config(self) -> Dict[str, Any]:
    config = super().get_config()
    config.update(
        initial_value=None,
        minimum=self.minimum,
        offset=self.offset,
        shape=tuple(map(int, self.variable.shape)),
        dtype=self.variable.dtype.name,
    )
    return config
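

# Illustrative smoke test (not part of the library): exercises both
# reparameterizations above. Shapes and values are made-up examples.
if __name__ == "__main__":
  # GDN beta for 8 channels: stored as sqrt(beta + offset**2), recovered on
  # call via lower_bound and squaring.
  beta = GDNParameter(tf.ones(8), name="beta")
  print(beta().numpy())  # Approximately all ones.

  # 2D convolution kernel (height, width, in_channels, out_channels): stored
  # as its RDFT over the spatial dimensions, reconstructed via inverse RDFT.
  kernel = RDFTParameter(tf.random.normal((5, 5, 3, 16)), name="kernel")
  print(kernel().shape)  # (5, 5, 3, 16)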