# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deep fully factorized distribution based on cumulative."""
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_compression.python.distributions import helpers
from tensorflow_compression.python.distributions import uniform_noise
__all__ = [
"DeepFactorized",
"NoisyDeepFactorized",
]


def log_expm1(x):
  """Computes log(exp(x)-1) stably.

  For large values of x, exp(x) will return Inf, whereas log(exp(x)-1) ~= x.
  Here we use this approximation for x > 15, such that the output is non-Inf
  for all positive values of x.

  Args:
    x: A tensor.

  Returns:
    log(exp(x)-1)
  """
  # If x < 15.0, we can compute log(exp(x)-1) directly. For larger values,
  # we use the approximation log(exp(x)-1) ~= log(exp(x)) = x.
  cond = (x < 15.0)
  x_small = tf.minimum(x, 15.0)
  return tf.where(cond, tf.math.log(tf.math.expm1(x_small)), x)
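
# For example, `log_expm1` agrees with the naive formula for small inputs but
# stays finite where the naive formula overflows:
#
#   x = tf.constant([0.5, 100.0])
#   log_expm1(x)                   # ~[-0.433, 100.0]
#   tf.math.log(tf.math.expm1(x))  # ~[-0.433, inf]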


class DeepFactorized(tfp.distributions.Distribution):
  """Fully factorized distribution based on neural network cumulative.

  This is a flexible, nonparametric probability density model, described in
  appendix 6.1 of the paper:

  > "Variational image compression with a scale hyperprior"<br />
  > J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston<br />
  > https://openreview.net/forum?id=rkcQFMZRb

  but *without* convolution with a unit-width uniform density, as described in
  appendix 6.2 of the same paper. Please cite the paper if you use this code
  for scientific work.

  This is a scalar distribution (i.e., its `event_shape` is always length 0),
  and the density object always creates its own `tf.Variable`s representing
  the trainable distribution parameters.
  """

  def __init__(self,
               batch_shape=(), num_filters=(3, 3), init_scale=10,
               allow_nan_stats=False, dtype=tf.float32, name="DeepFactorized"):
    """Initializer.

    Args:
      batch_shape: Iterable of integers. The desired batch shape for the
        `Distribution` (rightmost dimensions which are assumed independent,
        but not identically distributed).
      num_filters: Iterable of integers. The number of filters for each of the
        hidden layers. The first and last layer of the network implementing
        the cumulative distribution are not included (they are assumed to
        be 1).
      init_scale: Float. Scale factor for the density at initialization. It is
        recommended to choose a large enough scale factor such that most
        values initially lie within a region of high likelihood. This improves
        training.
      allow_nan_stats: Boolean. Whether to allow `NaN`s to be returned when
        querying distribution statistics.
      dtype: A floating point `tf.dtypes.DType`. Computations relating to this
        distribution will be performed at this precision.
      name: String. A name for this distribution.
    """
    parameters = dict(locals())
    self._batch_shape_tuple = tuple(int(s) for s in batch_shape)
    self._num_filters = tuple(int(f) for f in num_filters)
    self._init_scale = float(init_scale)
    super().__init__(
        dtype=dtype,
        reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED,
        validate_args=False,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name,
    )
    with self.name_scope:
      self._make_variables()

  @property
  def num_filters(self):
    return self._num_filters

  @property
  def init_scale(self):
    return self._init_scale

  def _make_variables(self):
    """Creates the variables representing the parameters of the distribution."""
    channels = self.batch_shape.num_elements()
    filters = (1,) + self.num_filters + (1,)
    scale = self.init_scale ** (1 / (len(self.num_filters) + 1))
    self._matrices = []
    self._biases = []
    self._factors = []

    for i in range(len(self.num_filters) + 1):
      def matrix_initializer(i=i):
        init = log_expm1(1 / scale / filters[i + 1])
        init = tf.cast(init, dtype=self.dtype)
        init = tf.broadcast_to(init, (channels, filters[i + 1], filters[i]))
        return init
      matrix = tf.Variable(matrix_initializer, name="matrix_{}".format(i))
      self._matrices.append(matrix)

      def bias_initializer(i=i):
        return tf.random.uniform(
            (channels, filters[i + 1], 1), -.5, .5, dtype=self.dtype)
      bias = tf.Variable(bias_initializer, name="bias_{}".format(i))
      self._biases.append(bias)

      if i < len(self.num_filters):
        def factor_initializer(i=i):
          return tf.zeros((channels, filters[i + 1], 1), dtype=self.dtype)
        factor = tf.Variable(factor_initializer, name="factor_{}".format(i))
        self._factors.append(factor)
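
  # For orientation: with the defaults batch_shape=() and num_filters=(3, 3),
  # the layer widths are filters = (1, 3, 3, 1), so the variables above get
  # shapes matrix_0: (1, 3, 1), matrix_1: (1, 3, 3), matrix_2: (1, 1, 3), and
  # bias_i/factor_i: (channels, filters[i + 1], 1). In effect, each channel
  # owns a tiny scalar-to-scalar MLP that parameterizes its CDF.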

  def _batch_shape_tensor(self):
    return tf.constant(self._batch_shape_tuple, dtype=int)

  def _batch_shape(self):
    return tf.TensorShape(self._batch_shape_tuple)

  def _event_shape_tensor(self):
    return tf.constant((), dtype=int)

  def _event_shape(self):
    return tf.TensorShape(())

  def _broadcast_inputs(self, inputs):
    shape = tf.broadcast_dynamic_shape(
        tf.shape(inputs), self.batch_shape_tensor())
    return tf.broadcast_to(inputs, shape)

  def _logits_cumulative(self, inputs):
    """Evaluates logits of the cumulative densities.

    Args:
      inputs: The values at which to evaluate the cumulative densities.

    Returns:
      A `tf.Tensor` of the same shape as `inputs`, containing the logits of
      the cumulative densities evaluated at the given inputs.
    """
    # Convert to (channels, 1, batch) format by collapsing dimensions and then
    # commuting channels to front.
    shape = tf.shape(inputs)
    inputs = tf.reshape(inputs, (-1, 1, self.batch_shape.num_elements()))
    inputs = tf.transpose(inputs, (2, 1, 0))

    logits = inputs
    for i in range(len(self.num_filters) + 1):
      matrix = tf.nn.softplus(self._matrices[i])
      logits = tf.linalg.matmul(matrix, logits)
      logits += self._biases[i]
      if i < len(self.num_filters):
        factor = tf.math.tanh(self._factors[i])
        logits += factor * tf.math.tanh(logits)

    # Convert back to (broadcasted) input tensor shape.
    logits = tf.transpose(logits, (2, 1, 0))
    logits = tf.reshape(logits, shape)
    return logits
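
  # Note: `tf.nn.softplus` keeps the matrix entries positive and tanh(factor)
  # lies in (-1, 1), so each layer above is strictly increasing in its input.
  # The composed network is therefore monotonic, which makes sigmoid(logits) a
  # valid CDF for every channel.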

  def _log_cdf(self, inputs):
    inputs = self._broadcast_inputs(inputs)
    logits = self._logits_cumulative(inputs)
    return tf.math.log_sigmoid(logits)

  def _log_survival_function(self, inputs):
    inputs = self._broadcast_inputs(inputs)
    logits = self._logits_cumulative(inputs)
    # 1 - sigmoid(x) = sigmoid(-x)
    return tf.math.log_sigmoid(-logits)

  def _cdf(self, inputs):
    inputs = self._broadcast_inputs(inputs)
    logits = self._logits_cumulative(inputs)
    return tf.math.sigmoid(logits)

  def _survival_function(self, inputs):
    inputs = self._broadcast_inputs(inputs)
    logits = self._logits_cumulative(inputs)
    # 1 - sigmoid(x) = sigmoid(-x)
    return tf.math.sigmoid(-logits)
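
  # Evaluating the log CDF/survival function via `tf.math.log_sigmoid` of the
  # logits (rather than taking `tf.math.log` of the sigmoid) avoids underflow
  # for large negative logits, keeping values finite far out in the tails.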

  def _prob(self, inputs):
    inputs = self._broadcast_inputs(inputs)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(inputs)
      cdf = self._cdf(inputs)
    prob = tape.gradient(cdf, inputs)
    return prob

  def _log_prob(self, inputs):
    inputs = self._broadcast_inputs(inputs)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(inputs)
      logits = self._logits_cumulative(inputs)
    # Let x = inputs and s(x) = sigmoid(x).
    # We have F(x) = s(logits(x)),
    # so p(x) = F'(x)
    #         = s'(logits(x)) * logits'(x)
    #         = s(logits(x)) * s(-logits(x)) * logits'(x),
    # so log p(x) = log s(logits(x)) + log s(-logits(x)) + log logits'(x).
    log_s_logits = tf.math.log_sigmoid(logits)
    log_s_neg_logits = tf.math.log_sigmoid(-logits)
    dlogits = tape.gradient(logits, inputs)
    return log_s_logits + log_s_neg_logits + tf.math.log(dlogits)

  def _quantization_offset(self):
    return helpers.estimate_tails(
        self._logits_cumulative, 0., self.batch_shape_tensor(), self.dtype)

  def _lower_tail(self, tail_mass):
    logits = tf.math.log(
        tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype))
    return helpers.estimate_tails(
        self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)

  def _upper_tail(self, tail_mass):
    logits = -tf.math.log(
        tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype))
    return helpers.estimate_tails(
        self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)

  @classmethod
  def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
    raise NotImplementedError(
        f"`{cls.__name__}` does not implement `_parameter_properties`.")


class NoisyDeepFactorized(uniform_noise.UniformNoiseAdapter):
  """`DeepFactorized` that is convolved with uniform noise."""

  def __init__(self, name="NoisyDeepFactorized", **kwargs):
    super().__init__(DeepFactorized(**kwargs), name=name)
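

# A minimal usage sketch (illustrative only; values are placeholders):
#
#   df = DeepFactorized(batch_shape=(2,))
#   x = tf.constant([[0.0, 1.5]])
#   df.prob(x)      # density at x, one value per channel
#
#   # `NoisyDeepFactorized` forwards its kwargs to `DeepFactorized` and is the
#   # variant typically used as an entropy model prior: the added unit-width
#   # uniform noise matches the training-time relaxation of quantization
#   # described in appendix 6.2 of the paper cited above.
#   prior = NoisyDeepFactorized(batch_shape=(2,), num_filters=(3, 3))
#   y = tf.random.normal((4, 2))
#   bits = -tf.reduce_sum(prior.log_prob(y)) / tf.math.log(2.)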