may, 2025
i put thing in black box, then i want thing back, but how? what does it take to build an invertible neural network?
in caveman terms, an invertible process is one we know how to undo. moving a banana is invertible, eating the banana is not.
mathematically, an invertible function is a bijective one, and a bijective function is both injective and surjective. these are defined as follows:
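for a function \( f: X \rightarrow Y \):
\begin{align*}
\text{injective: } &\quad \forall x_1, x_2 \in X, \; f(x_1) = f(x_2) \implies x_1 = x_2 \\
\text{surjective: } &\quad \forall y \in Y, \; \exists x \in X \text{ such that } f(x) = y
\end{align*}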
in human words, an injective function is one where each output comes from exactly one input, with no two inputs giving the same output, and a surjective function is one where every possible output is mapped to by some input. together, they give a bijection: given any output, i know the input that mapped to it exists and is unique.
example o' clock, the function \(y = 4x + 10\) is invertible, because given any \(y\), i can give you back \(x\) via the inverse function \(x = \frac{y - 10}{4}\). note that this function is analytically invertible, because we can find an analytical form of the inverse function.
without going into the math, it is pretty easy to see that most neural networks are not analytically invertible. for one, the most common choice of activation function, ReLU, would render the whole thing non-invertible:
\[ \text{ReLU}(x) = \max(0, x) \]
intuitively, if i got a positive number out of ReLU, i would know that that positive number is exactly the number that i put in. but if the output of ReLU is 0, then we would have no idea what we put in. and, since a composition of functions is only invertible if all pieces of the composition are invertible, the usage of ReLU renders the most common linear layer \( y = \phi(w^Tx + b) \) not analytically invertible. beyond linear layers, things like convolutions on images and message passing on graphs are also not invertible under similar reasoning. in both cases, the aggregation step renders the whole thing not analytically invertible - simply put, the result of an addition tells you nothing about what each component might have been.
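to make the ReLU point concrete, here is a tiny sketch in plain python (the function is just an illustration, not any library's implementation):

```python
# relu sends every negative input to 0, so it cannot be injective
def relu(x):
    return max(0.0, x)

print(relu(-3.0), relu(-7.0))  # both print 0.0 - the original input is unrecoverable
print(relu(2.5))               # 2.5 - positive inputs pass through unchanged
```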
so is all hope lost? can we go home now?
no, not yet.
never give up.
normalizing flows are a class of generative models. think of stable diffusion and the gpts: the objective is to produce more samples from the distribution of the data the model is trained on. if i feed it pictures of ikea sharks, it should learn how to generate more ikea-shark-looking plushies.
the key idea of normalizing flows is as follows: if we are able to transform the data distribution into a simpler distribution, like the standard normal (hence the name), via a series of invertible and differentiable mappings, then we can generate new samples from the data distribution by simply sampling noise from the standard normal and pushing the noise back through the inverse functions.
this class of generative models is at its core based on the change of variables formula in probability theory, which gives us a way to compute the resulting (or the pushforward) distribution after applying a function to a distribution.
given \( Z \in \mathbb{R}^D \), a random variable with a known and tractable probability density function \( p_Z: \mathbb{R}^D \rightarrow \mathbb{R} \), if \( g \) is an invertible function with inverse \( f = g^{-1} \), and \( Y = g(Z) \), then the change of variables formula gives us a way to compute the probability density function of the random variable \( Y \):
\[ p_Y(y) = p_Z(f(y)) \left| \det \frac{\partial f}{\partial y} \right| \]
the inverse of this transformation can be similarly computed, as long as \( g \) is invertible. this formula is informative in that it prescribes the necessary conditions for building such an invertible layer.
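to make this concrete, here is a minimal numerical sketch (assuming numpy and scipy are available) of the earlier affine example: with \( Z \sim \mathcal{N}(0, 1) \) and \( g(z) = 4z + 10 \), the formula should reproduce the density of \( \mathcal{N}(10, 4^2) \).

```python
# change of variables check: Z ~ N(0, 1), g(z) = 4z + 10, f(y) = (y - 10) / 4
import numpy as np
from scipy.stats import norm

def p_Y(y):
    f_y = (y - 10.0) / 4.0           # f = g^{-1}
    jac = 1.0 / 4.0                  # df/dy, a 1x1 "jacobian"
    return norm.pdf(f_y) * abs(jac)  # p_Z(f(y)) * |det df/dy|

ys = np.linspace(-5.0, 25.0, 7)
print(np.allclose(p_Y(ys), norm(loc=10.0, scale=4.0).pdf(ys)))  # True
```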
for a more comprehensive overview and a much more mathematically rigorous definition of normalizing flows, see this review.
this change of variables formula is also useful for inverse transform sampling, which is not super relevant here but super cool; you can read more about that in the bishop & bishop deep learning book, chapter 14.
that formula prescribes the necessary components for building such an invertible layer: we want a bijection \( g \) that is expressive when stacked in layers. and, importantly for building and training a model, we need this bijection, its inverse, and the determinant of its jacobian to be easily computable. if we had such a function, we would be able to build normalizing flows to our hearts' content.
that is, if we had such a function.
and we do.
original paper here
coupling layers, as proposed by dinh et al., give us such an invertible function. the core idea behind this sort of layer is to split the input \( x \) to each layer into two pieces, \( x_1 \) and \( x_2 \), and perform the following computation, with \( m \) being an arbitrarily complex function:
\begin{align*} y_1 &= x_1 \\ y_2 &= x_2 + m(x_1) \end{align*}
a little bit of rearranging very trivially gives us the inverse function:
\begin{align*} x_1 &= y_1 \\ x_2 &= y_2 - m(y_1) \end{align*}
note that both the forward and inverse computation of the coupling layer can happen with just the forward pass of the function \(m\), which can be as simple or as complex as the builder of the model desires it to be.
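here is a minimal sketch of such a layer, assuming pytorch; the class name and the little MLP standing in for \( m \) are illustrative choices, not anything prescribed by the paper.

```python
import torch
import torch.nn as nn

class AdditiveCoupling(nn.Module):
    """additive coupling: y1 = x1, y2 = x2 + m(x1)."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.half = dim // 2
        # m is never inverted, so it can freely contain ReLUs or anything else
        self.m = nn.Sequential(
            nn.Linear(self.half, hidden), nn.ReLU(),
            nn.Linear(hidden, dim - self.half),
        )

    def forward(self, x):
        x1, x2 = x[..., :self.half], x[..., self.half:]
        return torch.cat([x1, x2 + self.m(x1)], dim=-1)

    def inverse(self, y):
        y1, y2 = y[..., :self.half], y[..., self.half:]
        return torch.cat([y1, y2 - self.m(y1)], dim=-1)

# round trip: inverse(forward(x)) recovers x
layer = AdditiveCoupling(4)
x = torch.randn(8, 4)
assert torch.allclose(layer.inverse(layer(x)), x, atol=1e-6)
```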
of course, the coupling layer leaves part of the input completely unchanged, but this issue of expressivity can be fixed simply by stacking multiple coupling layers and alternating which half gets passed through untouched at each layer. a composition of invertible functions is invertible, and so this stacking of coupling layers gives us the analytically invertible network that we sought.
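continuing the sketch above (still assuming pytorch, and reusing the illustrative AdditiveCoupling class), the alternation can be done by flipping the features between coupling layers:

```python
class Flip(nn.Module):
    """reverse the feature order, swapping the roles of the two halves."""
    def forward(self, x):
        return x.flip([-1])
    def inverse(self, y):
        return y.flip([-1])

class TinyFlow(nn.Module):
    """a stack of coupling layers with flips in between; still invertible."""
    def __init__(self, dim, n_layers=4):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers += [AdditiveCoupling(dim), Flip()]
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def inverse(self, y):
        # invert each piece, in reverse order
        for layer in reversed(self.layers):
            y = layer.inverse(y)
        return y

flow = TinyFlow(4)
x = torch.randn(8, 4)
assert torch.allclose(flow.inverse(flow(x)), x, atol=1e-5)
```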
and that, my chums, is two out of three of our requirements to build a usable normalizing flow. all we need now is a fast way to compute the determinant of the jacobian of the function.
do you remember how to compute a matrix determinant?
o_o
the determinant of a matrix \(A\) is nonzero iff \(A\) is invertible, and it is, intuitively, the factor by which space is stretched after applying \(A\).
there is a beautiful 3b1b video with great visualizations on this topic!
how to compute it? well. you can do it via laplace expansion, which blows up to \( O(n!) \), or via LU decomposition and a variety of other methods at a time complexity of \( O(n^3) \). or, if your matrix is well behaved (not a mathematical definition), you can apply a number of properties of the determinant to compute it way faster.
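for example (a minimal numpy sketch), triangular matrices are one such well behaved case: the determinant is just the product of the diagonal entries.

```python
# for a triangular matrix, the determinant is the product of the diagonal
import numpy as np

rng = np.random.default_rng(0)
A = np.triu(rng.normal(size=(5, 5)))  # an upper triangular 5x5 matrix

print(np.prod(np.diag(A)))  # product of the diagonal entries
print(np.linalg.det(A))     # the general-purpose O(n^3) computation agrees
```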