[Short version] Information Loss --> Basin flatness

by Vivek Hebbar
21st May 2022

This is an overview for advanced readers.  Main post: Information Loss --> Basin flatness

Summary:

Inductive bias is related to, among other things:

  • Basin flatness
  • Which solution manifolds (manifolds of zero loss) are higher-dimensional than others.  This is closely related to "basin flatness", since each dimension of the manifold is a direction of zero curvature (see the toy example below).
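
As a toy illustration (my example, not from the post): fit $f(\theta, x) = \theta_1 \theta_2 x$ to the target $y = x$ with MSE.  The zero-loss manifold is the 1-D hyperbola $\theta_1 \theta_2 = 1$, and at every point on it the Hessian has exactly one zero eigenvalue, whose eigenvector is tangent to the manifold:

```python
import numpy as np

# Toy model f(theta, x) = theta1 * theta2 * x fit to y = x with MSE
# (my example, not the post's).  For the single input x = 1 the loss is
# L(theta) = (theta1 * theta2 - 1)^2, whose zero-loss set is the 1-D
# hyperbola theta1 * theta2 = 1.

def hessian(t1, t2):
    """Analytic Hessian of L = (t1*t2 - 1)^2."""
    return 2.0 * np.array([[t2 * t2,         2 * t1 * t2 - 1],
                           [2 * t1 * t2 - 1, t1 * t1]])

for t1 in [0.5, 1.0, 2.0]:           # points on the manifold: (t1, 1/t1)
    t2 = 1.0 / t1
    eigvals, eigvecs = np.linalg.eigh(hessian(t1, t2))
    flat_dir = eigvecs[:, np.argmin(np.abs(eigvals))]  # zero-curvature direction
    tangent = np.array([t1, -t2]) / np.hypot(t1, t2)   # tangent to t1*t2 = 1
    print(np.min(np.abs(eigvals)),                     # ~0: one flat direction
          abs(abs(flat_dir @ tangent) - 1.0))          # ~0: it is the tangent
```

One flat direction per manifold dimension: the zero-curvature directions of the basin are exactly the tangent directions of the solution manifold.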

In relation to basin flatness and manifold dimension:

  1. It is useful to consider the "behavioral gradients" $\nabla_\theta f(\theta, x_i)$, one for each input $x_i$. 
  2. Let $G$ be the matrix of behavioral gradients.  (The $i$th column of $G$ is $g_i = \nabla_\theta f(\theta, x_i)$.)[1]  We can show that $\dim(\text{manifold}) \le N - \operatorname{rank}(G)$.[2]
  3. $\operatorname{rank}(\text{Hessian}) = \operatorname{rank}(G)$.[3][4]  (Points 2 and 3 are checked numerically in the code sketch after this list.)
  4. Flat basin  $\approx$  Low-rank Hessian  $=$  Low-rank $G$  $\approx$  High manifold dimension
  5. High manifold dimension  $\approx$  Low-rank $G$  $=$  Linear dependence of the behavioral gradients 
  6. A case study of a very small neural network shows that "information loss" is a good qualitative interpretation of this linear dependence.
  7. Models that throw away enough information about the input in early layers are guaranteed to live on particularly high-dimensional manifolds.  Precise bounds seem easily derivable and might be given in a future post.
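
A minimal numerical check of points 2 and 3 (my sketch; the architecture, sizes, and seed are arbitrary choices, not from the post).  The trick for landing on an exact zero-loss minimum is to generate the targets from the network itself, so the residuals vanish exactly at $\theta^*$:

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: 1-hidden-layer tanh net with all parameters flattened
# into one vector theta (N = 12), evaluated on k = 5 random inputs.
n_in, n_hid, k = 2, 3, 5
N = n_hid * n_in + n_hid + n_hid           # W1 (3x2) + b1 (3) + w2 (3) = 12
X = torch.randn(k, n_in)

def f(theta, x):
    W1 = theta[:6].reshape(n_hid, n_in)
    b1 = theta[6:9]
    w2 = theta[9:]
    return torch.tanh(x @ W1.T + b1) @ w2  # one scalar output per input

theta_star = torch.randn(N)
y = f(theta_star, X)                       # targets the net fits exactly,
                                           # so theta_star is a zero-loss minimum

def loss(theta):
    return ((f(theta, X) - y) ** 2).sum()  # MSE (summed over inputs)

# Behavioral gradients: row i of J is g_i^T = grad_theta f(theta, x_i)^T,
# so G (N x k) is J transposed.
J = torch.autograd.functional.jacobian(lambda th: f(th, X), theta_star)
G = J.T
H = torch.autograd.functional.hessian(loss, theta_star)

print(torch.linalg.matrix_rank(G).item())  # generically 5 (= k, since k < N)
print(torch.linalg.matrix_rank(H).item())  # matches rank(G)
```

Since $J$ has full row rank here, the implicit function theorem makes the solution set near $\theta^*$ a manifold of dimension exactly $N - \operatorname{rank}(G) = 12 - 5 = 7$, saturating the bound in point 2.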

See the main post for details.

  1. In standard terminology, $G$ is the Jacobian of the concatenation of all outputs, w.r.t. the parameters.

  2. $N$ is the number of parameters in the model.  See claims 1 and 2 here for a proof sketch.
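
     A compressed version of the argument (my gloss; it assumes the solution manifold $M$ is smooth at $\theta$): any direction tangent to $M$ must leave every output unchanged to first order, so

     $$v \in T_\theta M \;\Longrightarrow\; g_i^\top v = 0 \;\text{ for all } i \;\Longrightarrow\; T_\theta M \subseteq \operatorname{span}(g_1, \dots, g_k)^\perp,$$

     and hence $\dim(M) = \dim(T_\theta M) \le N - \operatorname{rank}(G)$.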

  3. Proof sketch for $\operatorname{rank}(\text{Hessian}) = \operatorname{rank}(G)$:

     • $\operatorname{span}(g_1, \dots, g_k)^\perp$ is the set of directions in which the output is not first-order sensitive to parameter changes.  Its dimension is $N - \operatorname{rank}(G)$.
     • At a local minimum, first-order sensitivity of behavior translates to second-order sensitivity of loss (made explicit below).
     • So $\operatorname{span}(g_1, \dots, g_k)^\perp$ is the null space of the Hessian.
     • So $\operatorname{rank}(\text{Hessian}) = N - (N - \operatorname{rank}(G)) = \operatorname{rank}(G)$.
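
     To make the middle two steps explicit (a sketch assuming a per-example loss $\ell$ with $\ell' = 0$ and $\ell'' > 0$ at the minimum, e.g. MSE at zero loss): writing $L(\theta) = \sum_i \ell(f(\theta, x_i))$, the chain rule gives

     $$\nabla^2_\theta L \;=\; \sum_i \ell''\, g_i g_i^\top \;+\; \sum_i \ell'\, \nabla^2_\theta f(\theta, x_i) \;=\; \sum_i \ell''\, g_i g_i^\top,$$

     so $Hv = 0$ exactly when $g_i^\top v = 0$ for all $i$, i.e. $\operatorname{null}(H) = \operatorname{span}(g_1, \dots, g_k)^\perp$.
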
  4. There is an alternate proof going through the result $\text{Hessian} = 2GG^\top$.  (The constant 2 depends on MSE loss.)
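
     A quick autograd check of this identity (my toy; the model, inputs, and seed are arbitrary).  Generating the targets from the model itself makes $\theta^*$ an exact zero-loss minimum, where the identity holds exactly:

```python
import torch

torch.manual_seed(0)
X = torch.randn(4)
theta_star = torch.tensor([0.7, -1.3])

# Hypothetical 2-parameter model: f(theta, x) = theta2 * tanh(theta1 * x)
def f(theta, x):
    return theta[1] * torch.tanh(theta[0] * x)

y = f(theta_star, X)  # targets realized by theta_star => loss(theta_star) = 0

def loss(theta):
    return ((f(theta, X) - y) ** 2).sum()  # summed MSE

# Rows of J are the behavioral gradients g_i^T, so G = J.T and G G^T = J.T @ J
J = torch.autograd.functional.jacobian(lambda th: f(th, X), theta_star)
H = torch.autograd.functional.hessian(loss, theta_star)
print(torch.allclose(H, 2 * J.T @ J, atol=1e-5))  # True: Hessian = 2 G G^T
```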