Before we dive in, let's make sure we're using a GPU for this demo.
To do this, select "Runtime" -> "Change runtime type" -> "Hardware accelerator" -> "GPU".
The following snippet will verify that we have access to a GPU.
import tensorflow as tf

if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))
WARNING: GPU device not found.
Motivation
Wouldn't it be great if we could use TFP to specify a probabilistic model and then simply minimize the negative log-likelihood, i.e.,
negloglik = lambda y, rv_y: -rv_y.log_prob(y)
Well, not only is it possible, but this colab shows how! (In the context of linear regression problems.)
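As a minimal sketch of the idea (assuming the usual imports from earlier in the colab, e.g. `tfd = tfp.distributions`, plus training data `x`, `y`): the model's last layer outputs a distribution rather than a tensor, so `negloglik` can be used directly as the Keras loss. The fixed `scale=1.` here is an illustrative choice.

# Sketch: a linear model whose output is a distribution, trained by
# minimizing the negative log-likelihood defined above.
model = tf_keras.Sequential([
    tf_keras.layers.Dense(1),
    tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1.)),
])
model.compile(optimizer=tf_keras.optimizers.Adam(learning_rate=0.01),
              loss=negloglik)
model.fit(x, y, epochs=1000, verbose=False)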
# Specify the surrogate posterior over `keras.layers.Dense` `kernel` and `bias`.
def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
  n = kernel_size + bias_size
  c = np.log(np.expm1(1.))
  return tf_keras.Sequential([
      tfp.layers.VariableLayer(2 * n, dtype=dtype),
      tfp.layers.DistributionLambda(lambda t: tfd.Independent(
          tfd.Normal(loc=t[..., :n],
                     scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
          reinterpreted_batch_ndims=1)),
  ])
# Specify the prior over `keras.layers.Dense` `kernel` and `bias`.
def prior_trainable(kernel_size, bias_size=0, dtype=None):
  n = kernel_size + bias_size
  return tf_keras.Sequential([
      tfp.layers.VariableLayer(n, dtype=dtype),
      tfp.layers.DistributionLambda(lambda t: tfd.Independent(
          tfd.Normal(loc=t, scale=1),
          reinterpreted_batch_ndims=1)),
  ])
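These two functions plug into `tfp.layers.DenseVariational` as the surrogate posterior and the prior, respectively. A sketch of the resulting model (assuming `x`, `y`, and `negloglik` from above; scaling `kl_weight` by `1 / x.shape[0]` assumes full-batch training): the `1 + 1` output units feed a two-headed Normal, giving aleatoric uncertainty on top of the epistemic weight uncertainty.

# Sketch: weight uncertainty via DenseVariational, with a learned
# loc/scale head for aleatoric uncertainty.
model = tf_keras.Sequential([
    tfp.layers.DenseVariational(1 + 1, posterior_mean_field, prior_trainable,
                                kl_weight=1 / x.shape[0]),
    tfp.layers.DistributionLambda(
        lambda t: tfd.Normal(loc=t[..., :1],
                             scale=1e-3 + tf.math.softplus(0.01 * t[..., 1:]))),
])
model.compile(optimizer=tf_keras.optimizers.Adam(learning_rate=0.01),
              loss=negloglik)
model.fit(x, y, epochs=1000, verbose=False)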
plt.figure(figsize=[6, 1.5])  # inches
plt.plot(x, y, 'b.', label='observed')

yhats = [model(x_tst) for _ in range(100)]
avgm = np.zeros_like(x_tst[..., 0])
for i, yhat in enumerate(yhats):
  m = np.squeeze(yhat.mean())
  s = np.squeeze(yhat.stddev())
  if i < 15:
    plt.plot(x_tst, m, 'r',
             label='ensemble means' if i == 0 else None, linewidth=1.)
    plt.plot(x_tst, m + 2 * s, 'g', linewidth=0.5,
             label='ensemble means + 2 ensemble stdev' if i == 0 else None)
    plt.plot(x_tst, m - 2 * s, 'g', linewidth=0.5,
             label='ensemble means - 2 ensemble stdev' if i == 0 else None)
  avgm += m
plt.plot(x_tst, avgm / len(yhats), 'r', label='overall mean', linewidth=4)

plt.ylim(-0., 17)
plt.yticks(np.linspace(0, 15, 4)[1:])
plt.xticks(np.linspace(*x_range, num=9))
ax = plt.gca()
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data', 0))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
#ax.spines['left'].set_smart_bounds(True)
#ax.spines['bottom'].set_smart_bounds(True)
plt.legend(loc='center left', fancybox=True, framealpha=0.,
           bbox_to_anchor=(1.05, 0.5))
plt.savefig('/tmp/fig4.png', bbox_inches='tight', dpi=300)
Case 5: Functional Uncertainty
Custom PSD Kernel
class RBFKernelFn(tf_keras.layers.Layer):
  def __init__(self, **kwargs):
    super(RBFKernelFn, self).__init__(**kwargs)
    dtype = kwargs.get('dtype', None)

    self._amplitude = self.add_variable(
        initializer=tf.constant_initializer(0),
        dtype=dtype,
        name='amplitude')

    self._length_scale = self.add_variable(
        initializer=tf.constant_initializer(0),
        dtype=dtype,
        name='length_scale')

  def call(self, x):
    # Never called -- this is just a layer so it can hold variables
    # in a way Keras understands.
    return x

  @property
  def kernel(self):
    return tfp.math.psd_kernels.ExponentiatedQuadratic(
        amplitude=tf.nn.softplus(0.1 * self._amplitude),
        length_scale=tf.nn.softplus(5. * self._length_scale))
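Note the design choice here: `call` is an identity, because the class exists only so that Keras tracks the `amplitude` and `length_scale` variables. The `softplus` transforms keep both kernel parameters positive while leaving the underlying variables unconstrained; the `kernel` property is what `tfp.layers.VariationalGaussianProcess` consumes via its `kernel_provider` argument in the model below.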
# For numeric stability, set the default floating-point dtype to float64.
tf_keras.backend.set_floatx('float64')

# Build model.
num_inducing_points = 40
model = tf_keras.Sequential([
    tf_keras.layers.InputLayer(input_shape=[1]),
    tf_keras.layers.Dense(1, kernel_initializer='ones', use_bias=False),
    tfp.layers.VariationalGaussianProcess(
        num_inducing_points=num_inducing_points,
        kernel_provider=RBFKernelFn(),
        event_shape=[1],
        inducing_index_points_initializer=tf.constant_initializer(
            np.linspace(*x_range, num=num_inducing_points,
                        dtype=x.dtype)[..., np.newaxis]),
        unconstrained_observation_noise_variance_initializer=(
            tf.constant_initializer(np.array(0.54).astype(x.dtype))),
    ),
])

# Do inference.
batch_size = 32
loss = lambda y, rv_y: rv_y.variational_loss(
    y, kl_weight=np.array(batch_size, x.dtype) / x.shape[0])
model.compile(optimizer=tf_keras.optimizers.Adam(learning_rate=0.01),
              loss=loss)
model.fit(x, y, batch_size=batch_size, epochs=1000, verbose=False)

# Profit.
yhat = model(x_tst)
assert isinstance(yhat, tfd.Distribution)
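Because `yhat` is a `tfd.Distribution`, posterior predictive draws and moments come straight from it. A sketch (reusing the `x_tst` test grid from earlier in the colab):

# Sketch: query the fitted VGP posterior at the test points.
samples = yhat.sample(7)  # 7 posterior predictive draws over x_tst
mean = yhat.mean()        # pointwise predictive mean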
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-02-22 UTC."],[],[]]