3
\$\begingroup\$

I am trying to implement Deep Quaternion Networks. I was able to implement the batch normalization technique. But it requires a lot of GPU memory. Is there any way I can optimize the code provided below?

 class MyQuaternionBatchNorm2d(torch.nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(MyQuaternionBatchNorm2d, self).__init__()
        self.num_features = num_features
        self.qnum_features = num_features//4
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        
        if self.affine:
            self.weight = torch.nn.Parameter(torch.Tensor(self.qnum_features, 10))
            self.bias = torch.nn.Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.qnum_features,4))
            self.register_buffer('running_covar', torch.zeros(self.qnum_features,10))
            self.running_covar[:,0] = 1/ np.sqrt(4)
            self.running_covar[:,1] = 1/ np.sqrt(4)
            self.running_covar[:,2] = 1/ np.sqrt(4)
            self.running_covar[:,3] = 1/ np.sqrt(4)
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_buffer('running_mean',None)
            self.register_buffer('running_covar', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()
        
    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_covar.zero_()
            self.running_covar[:,0] = 1/ np.sqrt(4)
            self.running_covar[:,1] = 1/ np.sqrt(4)
            self.running_covar[:,2] = 1/ np.sqrt(4)
            self.running_covar[:,3] = 1/ np.sqrt(4)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        
        if self.affine:
            torch.nn.init.zeros_(self.weight)
            torch.nn.init.constant_(self.weight[:,0], 1/ np.sqrt(4))
            torch.nn.init.constant_(self.weight[:,4], 1/ np.sqrt(4))
            torch.nn.init.constant_(self.weight[:,7], 1/ np.sqrt(4))
            torch.nn.init.constant_(self.weight[:,9], 1/ np.sqrt(4))
            torch.nn.init.zeros_(self.bias)
            
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
    
    @staticmethod
    def _decomposition_v1(r,i,j,k,Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk):
        Wrr = torch.sqrt(Vrr)
        Wri = (1.0 / Wrr) * (Vri)
        Wii = torch.sqrt((Vii - (Wri.pow(2))))
        Wrj = (1.0 / Wrr) * (Vrj)
        Wij = (1.0 / Wii) * (Vij - (Wri*Wrj))
        Wjj = torch.sqrt((Vjj - (Wij.pow(2) + Wrj.pow(2))))
        Wrk = (1.0 / Wrr) * (Vrk)
        Wik = (1.0 / Wii) * (Vik - (Wri*Wrk))
        Wjk = (1.0 / Wjj) * (Vjk - (Wij*Wik + Wrj*Wrk))
        Wkk = torch.sqrt((Vkk - (Wjk.pow(2) + Wik.pow(2) + Wrk.pow(2))))
        
        cat_W_1 = torch.cat([Wrr, Wri, Wrj, Wrk])
        cat_W_2 = torch.cat([Wri,Wii, Wij, Wik])
        cat_W_3 = torch.cat([Wrj, Wij, Wjj, Wjk])
        cat_W_4 = torch.cat([Wrk, Wik, Wjk, Wkk])
        
        output =  cat_W_1[None,:,None,None]  *  r.repeat(1,4,1,1) + cat_W_2[None,:,None,None] *   i.repeat(1,4,1,1) \
                    + cat_W_3[None,:,None,None]  *   j.repeat(1,4,1,1) +  cat_W_4[None,:,None,None]  *  k.repeat(1,4,1,1)

        return output
    
    def forward(self, input):
        self._check_input_dim(input)
        r,i,j,k = torch.chunk(input, 4, dim=1)
        
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean_r, mean_i, mean_j, mean_k = r.mean([0, 2, 3]),i.mean([0, 2, 3]),j.mean([0, 2, 3]),k.mean([0, 2, 3])
            n = input.numel() / input.size(1)
            mean = torch.stack((mean_r, mean_i, mean_j, mean_k), dim=1)
            # update running mean
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean + (1 - exponential_average_factor) * self.running_mean
                    
            r = r-mean_r[None, :, None, None]
            i = i-mean_i[None, :, None, None]
            j = j-mean_j[None, :, None, None]
            k = k-mean_k[None, :, None, None]
            
            Vrr = (r.pow(2).mean([0, 2, 3])) + self.eps
            Vii = (i.pow(2).mean([0, 2, 3])) + self.eps
            Vjj = (j.pow(2).mean([0, 2, 3])) + self.eps
            Vkk = (k.pow(2).mean([0, 2, 3])) + self.eps
            Vri = ((r*i).mean([0, 2, 3]))
            Vrj = ((r*j).mean([0, 2, 3]))
            Vrk = ((r*k).mean([0, 2, 3]))
            Vij = ((i*j).mean([0, 2, 3]))
            Vik = ((i*k).mean([0, 2, 3]))
            Vjk = ((j*k).mean([0, 2, 3])) 

            with torch.no_grad():
                self.running_covar[:,0] = exponential_average_factor * Vrr * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,0]
                self.running_covar[:,1] = exponential_average_factor * Vii * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,1]
                self.running_covar[:,2] = exponential_average_factor * Vjj * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,2]
                self.running_covar[:,3] = exponential_average_factor * Vkk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,3]
                
                self.running_covar[:,4] = exponential_average_factor * Vri * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,4]
                self.running_covar[:,5] = exponential_average_factor * Vrj * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,5]
                self.running_covar[:,6] = exponential_average_factor * Vrk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,6]
                self.running_covar[:,7] = exponential_average_factor * Vij * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,7]
                self.running_covar[:,8] = exponential_average_factor * Vik * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,8]
                self.running_covar[:,9] = exponential_average_factor * Vjk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,9]
        else:
            mean = self.running_mean
            Vrr = self.running_covar[:,0]+self.eps
            Vii = self.running_covar[:,1]+self.eps
            Vjj = self.running_covar[:,2]+self.eps
            Vkk = self.running_covar[:,3]+self.eps
            
            Vri = self.running_covar[:,4]+self.eps
            Vrj = self.running_covar[:,5]+self.eps
            Vrk = self.running_covar[:,6]+self.eps
            Vij = self.running_covar[:,7]+self.eps
            Vik = self.running_covar[:,8]+self.eps
            Vjk = self.running_covar[:,9]+self.eps
           
            r = r-mean[None,:,0,None,None]
            i = i-mean[None,:,1,None,None]
            j = j-mean[None,:,2,None,None]
            k = k-mean[None,:,3,None,None]
            
        # standardized_output
        input = self._decomposition_v1(r,i,j,k, Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk)
        
        if self.affine:
            r,i,j,k = torch.chunk(input, 4, dim=1)
            
            cat_gamma_1 = torch.cat([self.weight[:,0], self.weight[:,1], self.weight[:,2], self.weight[:,3]])
            cat_gamma_2 = torch.cat([self.weight[:,1], self.weight[:,4], self.weight[:,5], self.weight[:,6]])
            cat_gamma_3 = torch.cat([self.weight[:,2], self.weight[:,5], self.weight[:,7], self.weight[:,8]])
            cat_gamma_4 = torch.cat([self.weight[:,3], self.weight[:,6], self.weight[:,8], self.weight[:,9]])


            input =  cat_gamma_1[None,:,None,None] * r.repeat(1,4,1,1) \
                    + cat_gamma_2[None,:,None,None] * i.repeat(1,4,1,1) \
                    + cat_gamma_3[None,:,None,None] * j.repeat(1,4,1,1) \
                    + cat_gamma_4[None,:,None,None] * k.repeat(1,4,1,1) \
                    + self.bias[None, :, None, None]
        return input

I will explain the forward section. So the basic formula for batch normalization is x* = (x - E[x]) / sqrt(var(x)), where x* is the new value of a single component, E[x] is its mean within a batch and var(x) is its variance within a batch.

However, as it is quaternion batch normalization,it has 4 parts r which is the real part and i, j, and k, are the imaginary part.

The equation extends to x* = W(x - E[x]) / (var(x)). W is one of the matrices from the Cholesky decomposition of V^-1 where V is the variance.In the code,E[x] is computed using the mean variable. V is computed in Vxy and V^-1 i.e. W is computed in the _decomposition_v1 function. This is applied to the input.

Finally, that formula further extends to x** = gamma * x* + beta, where x** is the final normalized value. gamma i.e. cat_gamma_x and beta i.e. self.beta are learned per layer.

Note: The num_features need to be in multiples of 4.

Thank you, Shreyas

\$\endgroup\$
2
  • \$\begingroup\$ Please tell us more about what the code is actually supposed to accomplish. Can you summarize the goal of the code? Thank you. \$\endgroup\$ Commented Feb 19, 2020 at 21:17
  • 2
    \$\begingroup\$ I edited the question. Please take a look and let me know if you need more information. Thank you. \$\endgroup\$ Commented Feb 20, 2020 at 3:28

0

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.