Hessian Computation using TensorFlow
Reference:
- TensorFlow implementation of K-means algorithm: https://codesachin.wordpress.com/2015/11/14/k-means-clustering-with-tensorflow/
- Calculating Hessian in Theano (application for Newton’s method): https://groups.google.com/forum/#!topic/theano-users/2c15kq68lp8
TensorFlow has a function called tf.gradients() that computes gradients. In the past, I tried to compute the Hessian of a neural network objective function in Torch7 using torch-autograd, but it was somewhat cumbersome; there wasn't an easy way to store and reshape parameters because Lua uses tables for everything. Today, I'd like to do the same thing in TensorFlow. It should be much easier than in Torch7 thanks to symbolic differentiation.
Example 1: Quadratic function
We are going to use $f(x) = \frac{1}{2} x^T A x + b^T x + c$ as our first example. When $A$ is a symmetric matrix, the Hessian of $f$ should be equal to $A$.
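This follows from differentiating twice:

$$\nabla f(x) = \frac{1}{2}\left(A + A^\top\right)x + b, \qquad \nabla^2 f(x) = \frac{1}{2}\left(A + A^\top\right),$$

which equals $A$ exactly when $A$ is symmetric.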
For simplicity, let us start with:
$$A = \left[
\begin{array}{rrr}
2 & 2 & 2 \\
2 & 2 & 2 \\
2 & 2 & 2
\end{array}
\right]
\quad
b = \left[
\begin{array}{rrr}
3 \\
3 \\
3
\end{array}
\right]
\quad
c = 1$$
The code below calculates the Hessian of $f(x)$ (a TensorFlow 1.x-style sketch; exact details of the original listing may differ). The idea is to call tf.gradients() once to get the gradient, then call it again on each entry of the gradient to get one row of the Hessian at a time.

import tensorflow as tf
import numpy as np

def getHessian(dim):
    # Each invocation builds a fresh graph.
    g = tf.Graph()
    with g.as_default():
        # Placeholders for the input and the function's coefficients.
        x = tf.placeholder(tf.float64, shape=[dim, 1])
        A = tf.placeholder(tf.float64, shape=[dim, dim])
        b = tf.placeholder(tf.float64, shape=[dim, 1])
        c = tf.placeholder(tf.float64, shape=[1])
        # f(x) = 1/2 x^T A x + b^T x + c
        fx = 0.5 * tf.matmul(tf.matmul(tf.transpose(x), A), x) \
             + tf.matmul(tf.transpose(b), x) + c
        # First derivatives of f with respect to x.
        dfx = tf.gradients(fx, x)[0]
        # Differentiate each entry of the gradient again:
        # each call yields one row of the Hessian.
        hess = []
        for i in range(dim):
            dfx_i = tf.slice(dfx, begin=[i, 0], size=[1, 1])
            ddfx_i = tf.gradients(dfx_i, x)[0]
            hess.append(ddfx_i)
        hess = tf.squeeze(tf.stack(hess))
        with tf.Session() as sess:
            feed_dict = {
                x: np.random.randn(dim, 1),
                A: 2.0 * np.ones((dim, dim)),
                b: 3.0 * np.ones((dim, 1)),
                c: [1.0],
            }
            print(sess.run(hess, feed_dict))

getHessian(3)
[[ 2. 2. 2.]
[ 2. 2. 2.]
[ 2. 2. 2.]]
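As an independent sanity check, the same Hessian can be approximated with central finite differences in plain NumPy (a minimal sketch, not tied to TensorFlow):

```python
import numpy as np

# Coefficients matching the example above.
A = 2.0 * np.ones((3, 3))
b = 3.0 * np.ones(3)
c = 1.0

def f(x):
    # f(x) = 1/2 x^T A x + b^T x + c
    return 0.5 * x @ A @ x + b @ x + c

def numerical_hessian(f, x, eps=1e-4):
    # Central second differences: H[i, j] ~ d^2 f / (dx_i dx_j).
    n = x.size
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            ei = np.zeros(n); ei[i] = eps
            ej = np.zeros(n); ej[j] = eps
            H[i, j] = (f(x + ei + ej) - f(x + ei - ej)
                       - f(x - ei + ej) + f(x - ei - ej)) / (4 * eps ** 2)
    return H

H = numerical_hessian(f, np.random.randn(3))
print(np.round(H, 3))  # all entries ~2.0, i.e. A
```

Since $f$ is quadratic, the central difference is exact up to floating-point rounding, so the result matches $A$ to high precision.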
We can see that the result of sess.run(hess, feed_dict) is indeed the desired value, $A$.
Example 2: Multilayer perceptron
Next, we'll try a small neural network model: a multilayer perceptron. We need to modify our getHessian function a little: we create one long vector for all the parameters and then slice it up according to the model architecture. Otherwise tf.gradients() cannot give us the full Hessian matrix over every parameter at once.
The code below is a TensorFlow 1.x-style sketch of this approach (exact details of the original listing may differ; a 4-5-1 architecture is assumed, which gives the 31 parameters seen in the output):

import tensorflow as tf
import numpy as np

def getHessianMLP(n_input, n_hidden, n_output):
    # Total parameter count: weights and biases of both layers.
    n_params = n_input * n_hidden + n_hidden + n_hidden * n_output + n_output
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float64, shape=[1, n_input])
        y = tf.placeholder(tf.float64, shape=[1, n_output])
        # One long vector holding every parameter, sliced per layer below.
        params = tf.placeholder(tf.float64, shape=[n_params, 1])
        idx = 0
        W1 = tf.reshape(tf.slice(params, [idx, 0], [n_input * n_hidden, 1]),
                        [n_input, n_hidden])
        idx += n_input * n_hidden
        b1 = tf.reshape(tf.slice(params, [idx, 0], [n_hidden, 1]), [1, n_hidden])
        idx += n_hidden
        W2 = tf.reshape(tf.slice(params, [idx, 0], [n_hidden * n_output, 1]),
                        [n_hidden, n_output])
        idx += n_hidden * n_output
        b2 = tf.reshape(tf.slice(params, [idx, 0], [n_output, 1]), [1, n_output])
        # Forward pass: one-hidden-layer MLP with sigmoid activations,
        # squared-error loss.
        h = tf.sigmoid(tf.matmul(x, W1) + b1)
        yhat = tf.sigmoid(tf.matmul(h, W2) + b2)
        loss = tf.reduce_sum(tf.square(y - yhat))
        # Gradient with respect to the flat parameter vector, then one more
        # tf.gradients() call per entry for each row of the Hessian.
        dfx = tf.gradients(loss, params)[0]
        hess = []
        for i in range(n_params):
            dfx_i = tf.slice(dfx, begin=[i, 0], size=[1, 1])
            hess.append(tf.gradients(dfx_i, params)[0])
        hess = tf.squeeze(tf.stack(hess))
        with tf.Session() as sess:
            feed_dict = {
                x: np.random.randn(1, n_input),
                y: np.random.randn(1, n_output),
                params: 0.1 * np.random.randn(n_params, 1),
            }
            H = sess.run(hess, feed_dict)
            print(H.shape)
            print(H)

getHessianMLP(n_input=4, n_hidden=5, n_output=1)
(31, 31)
[[ 2.19931314e-03 -1.42659002e-03 1.30957202e-03 -8.70158256e-04
8.50890204e-03 -5.51932165e-03 5.06659178e-03 -3.36654740e-03
7.25943921e-03 -4.70885402e-03 4.32260381e-03 -2.87219742e-03
1.87662840e-02 -1.21727986e-02 1.11743081e-02 -7.42488075e-03
-5.90046346e-02 8.59218910e-02 -2.69172415e-02 -1.75508182e-03
3.87416431e-03 -2.11908249e-03 -5.98554593e-03 1.32124824e-02
-7.22693698e-03 3.56099289e-03 -7.86052924e-03 4.29953635e-03
-1.33406427e-02 2.94481106e-02 -1.61074679e-02]
[ -1.42659002e-03 2.15499196e-03 2.23391340e-03 5.85207134e-04
-5.51932165e-03 8.33742879e-03 8.64276756e-03 2.26410269e-03
-4.70885402e-03 7.11314566e-03 7.37364776e-03 1.93163764e-03
-1.21727986e-02 1.83881018e-02 1.90615226e-02 4.99345176e-03
-2.40529864e-03 -1.63234770e-02 1.87287778e-02 -5.11422716e-02
6.40287027e-02 -1.28864162e-02 -1.72071008e-03 -1.16775408e-02
1.33982506e-02 1.02370558e-03 6.94734277e-03 -7.97104836e-03
-3.83513537e-03 -2.60270163e-02 2.98621524e-02] ... [omitted] ]
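The (31, 31) shape is consistent with a 4-5-1 network: 4·5 + 5 + 5·1 + 1 = 31 parameters. As a TensorFlow-free sanity check, a NumPy finite-difference Hessian of the same kind of sliced-parameter MLP (an illustrative sketch with assumed layer sizes and random data) reproduces the shape and confirms that the Hessian is symmetric:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mlp_loss(params, x, y, n_in=4, n_hid=5, n_out=1):
    # Slice one long parameter vector into weights and biases,
    # mirroring the sliced-parameter approach described above.
    i = 0
    W1 = params[i:i + n_in * n_hid].reshape(n_in, n_hid); i += n_in * n_hid
    b1 = params[i:i + n_hid];                             i += n_hid
    W2 = params[i:i + n_hid * n_out].reshape(n_hid, n_out); i += n_hid * n_out
    b2 = params[i:i + n_out]
    h = sigmoid(x @ W1 + b1)
    yhat = sigmoid(h @ W2 + b2)
    return float(np.sum((y - yhat) ** 2))

def numerical_hessian(f, p, eps=1e-4):
    # Central second differences over the flat parameter vector.
    n = p.size
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            ei = np.zeros(n); ei[i] = eps
            ej = np.zeros(n); ej[j] = eps
            H[i, j] = (f(p + ei + ej) - f(p + ei - ej)
                       - f(p - ei + ej) + f(p - ei - ej)) / (4 * eps ** 2)
    return H

rng = np.random.RandomState(0)
n_params = 4 * 5 + 5 + 5 * 1 + 1  # 31, assuming a 4-5-1 architecture
p = 0.1 * rng.randn(n_params)
x = rng.randn(1, 4)
y = rng.randn(1, 1)
H = numerical_hessian(lambda q: mlp_loss(q, x, y), p)
print(H.shape)  # (31, 31)
```

Because mixed partial derivatives commute, the finite-difference Hessian should agree with its transpose up to numerical error, which is a useful correctness check on the symbolic version too.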