Add post on Kolmogorov-Arnold Networks

This commit is contained in:
Dimitri Lozeve 2024-06-08 12:55:24 +02:00
parent fe066d13ed
commit 681974170e
2 changed files with 127 additions and 0 deletions

@@ -0,0 +1,81 @@
---
title: "Reading notes: Kolmogorov-Arnold Networks"
date: 2024-06-08
tags: machine learning, paper
toc: false
---

This paper [cite:@liu2024_kan] proposes an alternative to multi-layer
perceptrons (MLPs) in machine learning.

The basic idea is that MLPs have parameters on the nodes of the
computation graph (the weights and biases of each neuron), whereas
KANs have their parameters on the edges. Each edge has a learnable
activation function parameterized as a spline.
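
To make this concrete, here is a minimal, illustrative sketch of a
single KAN layer. It is not the paper's implementation: the per-edge
splines are replaced by simple piecewise-linear interpolation on a
fixed grid (the paper uses B-splines plus a base function), but the
structure is the same: all learnable parameters live on the edges,
and nodes only sum.

#+begin_src python
import numpy as np

class ToyKANLayer:
    """Toy KAN layer: one learnable univariate function per edge.

    Each edge (i -> j) carries its own function phi_{j,i}, represented
    here as a piecewise-linear interpolation of learnable values on a
    fixed grid. Each output node simply sums its incoming edges: no
    weights or biases live on the nodes themselves."""

    def __init__(self, in_dim, out_dim, grid_size=5, x_min=-1.0, x_max=1.0):
        self.grid = np.linspace(x_min, x_max, grid_size)  # shared grid points
        # one coefficient vector per edge: the "internal" degrees of freedom
        self.coef = 0.1 * np.random.randn(in_dim, out_dim, grid_size)

    def forward(self, x):  # x has shape (batch, in_dim)
        batch, in_dim = x.shape
        out_dim = self.coef.shape[1]
        out = np.zeros((batch, out_dim))
        for i in range(in_dim):
            for j in range(out_dim):
                # phi_{j,i}(x_i): the learnable activation on edge i -> j
                out[:, j] += np.interp(x[:, i], self.grid, self.coef[i, j])
        return out
#+end_src

Stacking several such layers gives a KAN; the shape of the graph
(widths, which edges survive pruning) constitutes the external
degrees of freedom described below, and the per-edge coefficients the
internal ones.
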
The network is learned at two levels, which allows for "adjusting
locally":
- the overall shape of the computation graph and its connections
(external degrees of freedom, to learn the compositional structure),
- the parameters of each activation function (internal degrees of
  freedom).

It is based on the [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Arnold_representation_theorem][Kolmogorov-Arnold representation theorem]], which
says that any continuous multivariate function can be represented
using only sums and compositions of continuous univariate
functions. We recover the distinction between the compositional
structure (how the univariate functions are combined) and the
structure of each internal univariate function.
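
For reference, the theorem states that any continuous function
\(f : [0,1]^n \to \mathbb{R}\) can be written as

\[
f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right),
\]

where the \(\Phi_q\) and the \(\phi_{q,p}\) are continuous univariate
functions.
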
The theorem can be interpreted as a two-layer network, and the paper
then generalizes it to multiple layers of arbitrary width. In the
theorem, the univariate functions are arbitrary and can be very
complex (even fractal), so the hope is that allowing arbitrary depth
and width will make it possible to use only splines. They derive an
approximation theorem: when replacing the arbitrary continuous
functions of the Kolmogorov-Arnold representation with splines, the
error can be bounded independently of the dimension. (However, there
is a constant which depends on the function and its representation,
and therefore on the dimension...) The theoretical scaling laws in
the number of parameters are much better than for MLPs, and moreover,
experiments show that KANs are much closer to their theoretical
bounds than MLPs.
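
Writing \(f_{\mathrm{KAN},G}\) for the spline-based KAN with grid
size \(G\) and spline order \(k\), the bound is roughly of the form

\[
\lVert f - f_{\mathrm{KAN},G} \rVert_{C^m} \le C \, G^{-k-1+m},
\]

where \(\lVert \cdot \rVert_{C^m}\) measures derivatives up to order
\(m\) and \(C\) is the constant depending on \(f\) and its chosen
representation: the input dimension never appears in the exponent,
only (possibly) in \(C\).
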
KANs have interesting properties:
- The splines are interpolated on grid points which can be iteratively
  refined. The fact that there is a notion of "fine-grainedness" is
  very interesting: it allows adding parameters without having to
  retrain everything (see the grid-extension sketch after this list).
- Larger is not always better: the quality of the reconstruction
  depends on finding the optimal shape of the network, which should
  match the structure of the function we want to approximate. This
  optimal shape is found via sparsification, pruning, and
  regularization (which is non-trivial).
- We can have a "human in the loop" during training, guiding pruning
  and "symbolifying" some activations (e.g. recognizing that an
  activation function is actually a cosine and replacing it directly;
  see the second sketch after this list). This symbolic discovery can
  be guided by a symbolic system that recognizes common functions. It
  is therefore a mix of symbolic regression and numerical regression.
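
As a sketch of grid refinement (under the same piecewise-linear
simplification as the toy layer above, not the paper's grid extension
for B-splines), a layer can be moved to a finer grid without
discarding what it has already learned:

#+begin_src python
import numpy as np

def refine_grid(layer, new_grid_size):
    """Re-interpolate each edge function of a ToyKANLayer onto a finer grid.

    The refined layer computes (approximately) the same functions as
    before, but now has more parameters per edge available for further
    training."""
    new_grid = np.linspace(layer.grid[0], layer.grid[-1], new_grid_size)
    in_dim, out_dim, _ = layer.coef.shape
    new_coef = np.zeros((in_dim, out_dim, new_grid_size))
    for i in range(in_dim):
        for j in range(out_dim):
            # sample the current edge function at the new grid points
            new_coef[i, j] = np.interp(new_grid, layer.grid, layer.coef[i, j])
    layer.grid, layer.coef = new_grid, new_coef
#+end_src

And a much simplified version of the symbolification step: sample a
learned edge function, fit a few symbolic candidates by least
squares, and snap to the best match (the paper fits a full affine
transform \(a\,g(bx + c) + d\) against a larger library of
candidates; here only an output-side affine fit is used):

#+begin_src python
import numpy as np

CANDIDATES = {"sin": np.sin, "cos": np.cos, "tanh": np.tanh, "x^2": np.square}

def snap_to_symbolic(xs, ys):
    """Fit y ~ a * g(x) + b for each candidate g, return the best fit by R^2."""
    best = None
    for name, g in CANDIDATES.items():
        design = np.stack([g(xs), np.ones_like(xs)], axis=1)  # columns [g(x), 1]
        (a, b), *_ = np.linalg.lstsq(design, ys, rcond=None)
        residuals = ys - (a * g(xs) + b)
        r2 = 1.0 - residuals @ residuals / np.sum((ys - ys.mean()) ** 2)
        if best is None or r2 > best[1]:
            best = (name, r2, (a, b))
    return best  # e.g. ("cos", 0.99, (a, b)) if the edge has learned a cosine
#+end_src
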
They test mostly with scientific applications in mind: reconstructing
equations from physics and pure maths. Conceptually, it has a lot of
overlap with Neural Differential Equations
[cite:@chenNeuralOrdinaryDifferential2018;@ruthotto2024_differ_equat]
and "scientific ML" in general.
There is an interesting discussion at the end about KANs as the model
of choice for the "language of science". The idea is that LLMs are
important because they are useful for natural language, and KANs
could fill the same role for the language of functions. Their
interpretability and adaptability (the ability to be manipulated and
guided during training by a domain expert) are thus core features
that traditional deep learning models lack.

There are still challenges. It is mostly unclear how KANs perform on
other types of data and other modalities, but the results are very
encouraging. There is also a computational challenge: KANs are
obviously much slower to train, but almost no engineering work has
gone into optimizing them yet, so this is expected. The fact that the
operations are not easily batchable (compared to matrix
multiplication) is however worrying for scalability to large
networks.
* References