A path towards autonomous machine intelligence

Introduction

Three main challenges in AI:

  1. How can machines learn to represent and predict the world through observation?
  2. How can machines reason and plan in ways that are compatible with gradient-based learning?
  3. How can machines learn to represent action plans hierarchically, across multiple levels of abstraction and time scales?

Joint Embedding Predictive Architecture (JEPA)

JEPA is not generative: it cannot directly predict $y$ from $x$. Instead, it captures the dependencies between $x$ and $y$ without explicitly generating predictions of $y$.

JEPA flow:

  1. The two variables $x$ and $y$ are fed to two encoders producing two representations $s_x$ and $s_y$. These two encoders may be different, allowing $x$ and $y$ to be different in nature (e.g. video and audio).
  2. A predictor module predicts the representation of $y$ from the representation of $x$, depending on a latent variable $z$.
  3. The energy is defined as the prediction error in the representation space: $E_w(x,y,z) = D(s_y, \text{Pred}(s_x, z))$
  4. The overall energy is obtained by minimising over $z$ (a minimal sketch of the whole flow follows this list): \begin{align*} \hat{z} & = \text{argmin}_{z\in \mathcal{Z}} E_w(x,y,z) = \text{argmin}_{z\in \mathcal{Z}} D(s_y, \text{Pred}(s_x, z)) \\ F_w(x,y) &= \min_{z\in\mathcal{Z}}E_w(x,y,z) = D(s_y, \text{Pred}(s_x, \hat{z}))\end{align*} ![[Screen Shot 2022-09-15 at 9.34.34 pm.png]]
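
A minimal sketch of this flow, assuming simple MLP encoders, a latent concatenated to $s_x$, and squared Euclidean distance for $D$ (the module shapes and names are illustrative assumptions, not taken from the paper):

```python
import torch
import torch.nn as nn

class JEPA(nn.Module):
    """Minimal JEPA sketch: two encoders, a latent-conditioned predictor,
    and an energy defined as a distance in representation space."""
    def __init__(self, dim_x=32, dim_y=32, dim_s=16, dim_z=4):
        super().__init__()
        self.enc_x = nn.Sequential(nn.Linear(dim_x, dim_s), nn.ReLU(), nn.Linear(dim_s, dim_s))
        self.enc_y = nn.Sequential(nn.Linear(dim_y, dim_s), nn.ReLU(), nn.Linear(dim_s, dim_s))
        self.pred = nn.Sequential(nn.Linear(dim_s + dim_z, dim_s), nn.ReLU(), nn.Linear(dim_s, dim_s))

    def energy(self, x, y, z):
        s_x, s_y = self.enc_x(x), self.enc_y(y)
        s_y_tilde = self.pred(torch.cat([s_x, z], dim=-1))   # Pred(s_x, z)
        # E_w(x, y, z) = D(s_y, Pred(s_x, z)); here D is the squared Euclidean distance
        return ((s_y - s_y_tilde) ** 2).sum(dim=-1)
```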

The main advantage of JEPA is that it doesn’t need to generate $y$, because it only predicts its representation $s_y$ in the representation space. Thus, JEPA doesn’t need to predict every detail of $y$. In other words, the encoder of JEPA learns to produce an abstract representation from which irrelevant details have been eliminated.

However, a problem with this is that the world is inherently unpredictable. LeCun gives several reasons for this:

  • The world is inherently stochastic
  • The world is chaotic and hence difficult to predict without infinitely precise perception
  • The world is only partially observable

So, we need some way to represent the multiplicity of values of $y$ compatible with a given $x$. LeCun suggests two ways to do this:

  1. Multi-modality through encoder invariance: all the possible $y$’s compatible with a given $x$ have the same energy; that is, they all map to the same $s_y$. Whilst this means that JEPA can’t recreate the outputs from the inputs, this is similar to how we operate in the real world. One might imagine this like a car on a freeway, just after a freeway entry. Did the car come from the entry, or has it been on the freeway the entire time? Both are plausible answers, and entertaining both is consistent with understanding the world. And yet we couldn’t reconstruct the input $x$ (i.e. where the car came from) without further information.
  2. Multi-modality through latent variable predictor: in contrast, the predictor may use a latent variable $z$ to capture the information necessary to predict $s_y$ that is not present in $s_x$. Continuing the example above, the latent variable $z$ might be binary, indicating whether the car came from the entry ($z=0$) or was already on the freeway ($z=1$). If it came from the freeway, then the value $z=1$ will produce a lower energy $D(s_y, \tilde{s}_y)$ than $z=0$. Generally, we vary $z$ over a set $\mathcal{Z}$, and the predictor module of JEPA produces a set of plausible predictions $\text{Pred}(s_x, \mathcal{Z})=\{\tilde{s}_y = \text{Pred}(s_x, z) \hspace{2mm} \forall z \in \mathcal{Z}\}$.
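
A hedged sketch of how the overall energy $F_w(x,y)=\min_{z\in\mathcal{Z}} E_w(x,y,z)$ could be evaluated when $\mathcal{Z}$ is a small discrete set, reusing the hypothetical `JEPA` module sketched above:

```python
import torch

def free_energy(model, x, y, z_set):
    """F_w(x, y) = min over a discrete set of latents z of E_w(x, y, z).
    `model` is the hypothetical JEPA module sketched earlier."""
    energies = torch.stack([model.energy(x, y, z.expand(x.shape[0], -1)) for z in z_set])
    return energies.min(dim=0).values  # one energy per (x, y) pair in the batch

# For the freeway example, a binary latent (dim_z = 1) might be:
#   z_set = [torch.tensor([[0.0]]),   # z = 0: the car came from the entry
#            torch.tensor([[1.0]])]   # z = 1: the car was already on the freeway
```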

Training a JEPA

LeCun proposes training a JEPA through non-contrastive methods, as contrastive methods tend to become inefficient in high dimensions; the relevant dimension here is that of the representation $s_y$.

Training energy-based models

Training an EBM consists of constructing an architecture to compute the energy function $F_w(x,y)$, parametrised by a parameter vector $w$. The training process must find a $w$ that gives the right shape to the energy function: values of $y$ associated with $x$ in the training set should receive low energies, and other values should receive higher energies.

![[Screen Shot 2022-09-15 at 9.56.36 pm.png | 300]]

So we need a loss function $L(x,y,F_w(x,y))$ that we can minimise in order to make the energy of a training pair $(x,y)$ lower than the energies of $(x,\hat{y})$ for any other $\hat{y}$. It’s easy to make the energy of a training sample $(x,y)$ low: just make the loss an increasing function of the energy. The harder problem is keeping the energy of $(x,\hat{y})$ high for other values $\hat{y}$: the energy function could simply learn to give the same energy to every value of $y$, suffering a collapse and becoming flat. In essence, we want to ensure $F_w(x,\hat{y}) > F_w(x,y) \hspace{5mm} \forall \hat{y}\neq y$. We are used to neural networks that output a single prediction $\tilde{y}$ for any $x$. These do not suffer from collapse, because they give zero energy whenever $y=\tilde{y}$ and higher energy otherwise. Other architectures, particularly generative ones, can collapse.

LeCun notes two methods to prevent collapse: contrastive methods and regularised methods.

![[Screen Shot 2022-09-15 at 10.07.57 pm.png | 500]]

Contrastive methods use a loss whose minimisation pushes down the energy of observed training pairs $(x,y)$ and pulls up the energy of artificially generated “contrastive” samples $(x,\hat{y})$. Of course, we pick $\hat{y}$ so that it lies outside the region of high data density and hence is unlikely to lie on the manifold representing the structure of the data. These methods are used widely in neural networks, particularly for Siamese networks, where the training inputs are $x$ and $y$, with $x$ being a distorted version of $y$, and $\hat{y}$ being a different training sample. However, contrastive methods have two main issues:

  • We need to pick a suitable $\hat{y}$ that doesn’t confuse the network, i.e. we might confuse the network by picking a $\hat{y}$ that could actually have arisen from our given $x$, albeit accidentally. There are many methods to do this, such as Markov-Chain Monte Carlo methods, but these can be expensive, particularly in high dimensions.
  • Another issue stemming from the curse of dimensionality is that the higher-dimensional our target vector $y$, the more contrastive samples we need to ensure the energy is higher in all dimensions “unoccupied by the local data distribution”. The number of samples required may grow exponentially with the dimension of $y$.

Regularised methods somewhat address the curse of dimensionality by simultaneously minimising the energies of training pairs and minimising the volume of space around the samples to which the model assigns low energy. In effect, the low-energy region “shrink-wraps” the regions of high data density, which is achieved through a regularisation term in the loss. However, the design of this term depends on the model architecture.
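
To make the contrastive idea concrete, here is a sketch of a standard margin (hinge) loss over energies; it is a generic illustration using the hypothetical `JEPA` energy above (with a fixed latent for simplicity), not a loss prescribed by the paper:

```python
import torch

def contrastive_margin_loss(model, x, y, y_hat, z, margin=1.0):
    """Push the energy of observed pairs (x, y) below that of contrastive
    pairs (x, y_hat) by at least `margin`."""
    e_pos = model.energy(x, y, z)      # pushed down by the loss
    e_neg = model.energy(x, y_hat, z)  # pushed up until the margin is satisfied
    return torch.relu(margin + e_pos - e_neg).mean()
```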

Non-contrastive methods for training JEPA

We want to use regularisers to minimise the volume of space that can take on low energy values. LeCun notes that this can be achieved with four criteria:

  1. Maximise the information content of $s_x$ about $x$
  2. Maximise the information content of $s_y$ about $y$
  3. Make $s_y$ easily predictable from $s_x$
  4. Minimise the information content of the latent variable $z$ used in the prediction

The balance between these four criteria produces the desired regularisation effect. The first two criteria force the encoders to actually do something; otherwise, the encodings could be set to constants, making the energy constant over the input space. This prevents informational collapse. Criterion 3 arises from the energy term $D(s_y, \tilde{s}_y)$ and ensures $s_y$ is predictable from $s_x$. Finally, criterion 4 prevents the system from relying entirely on the latent variable: with an unconstrained $z$, the minimisation over $z$ could make the prediction match $s_y$ exactly, driving $D(s_y, \tilde{s}_y)$ to zero without the encoders learning anything useful.

[!note] Regularisers that minimise information content of latent variables Imagine that the energy function is augmented by a regularisation term on $z$ of the form $R(z)=\alpha \sum_{i=1}^d |z_i|$, i.e. the $L_1$ norm of $z$. This will drive the inferred $\hat{z}$ to be sparse. As with classical sparse coding, this will cause the region of low energy to be approximated by a union of low-dimensional manifolds (a union of low-dimensional linear subspaces if $\text{Pred}(s_x, z)$ is linear in $z$), whose dimension will be minimised by the $L_1$ regulariser.
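
As an illustration of that note, a sketch of inferring a sparse latent by gradient descent on the regularised energy; the step count, learning rate, and $\alpha$ are arbitrary, and `model` is the hypothetical `JEPA` module from earlier:

```python
import torch

def infer_sparse_latent(model, x, y, dim_z=4, alpha=0.1, steps=100, lr=0.1):
    """Infer z by minimising E_w(x, y, z) + alpha * ||z||_1 with gradient descent.
    The L1 term drives the inferred latent towards sparsity."""
    z = torch.zeros(x.shape[0], dim_z, requires_grad=True)
    opt = torch.optim.SGD([z], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = (model.energy(x, y, z) + alpha * z.abs().sum(dim=-1)).mean()
        loss.backward()
        opt.step()
    return z.detach()
```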

In this way, the main idea of JEPA is realised: it can choose to ignore details of the inputs that are not easily predictable, and yet our first two criteria ensure that we don’t ignore necessary details. LeCun then asks how we might maximise the information content of $s_y$ given a parametrised, deterministic encoding function $s_y = \text{Enc}_w(y)$. $s_y$ is maximally informative about $y$ if the function $\text{Enc}_w(y)$ is as close to invertible as possible, i.e. if the volume of the sets of $y$ values that map to the same $s_y$ is minimal. The same reasoning applies to the $x$ encoder. We now want to turn this criterion into a differentiable loss.

VICReg

The VICReg method (Bardes et al., 2021) uses two sub-criteria: (1) the components of $s_x$ must not be constant, and (2) the components of $s_x$ must be as independent of each other as possible.

They achieve this by “expanding” $s_x$ and $s_y$ to higher-dimensional embeddings $v_x$ and $v_y$ through a non-linear mapping, and then minimising a differentiable loss with two terms:

  1. Variance: maintains the standard deviation of each component of $s_y$ and $v_y$ above a certain threshold.
  2. Covariance: the covariances between different pairs of $v_y$ components are pushed towards zero. This has the effect of decorrelating the components of $v_y$, making the components of $s_y$ somewhat independent (both terms are sketched in code below).

The same criteria apply to $s_x$ and $v_x$. The third criterion is then implemented by the prediction error $D(s_y, \tilde{s}_y)$. The fourth criterion is used when the predictor relies on a latent variable whose information content must be minimised.
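
A rough sketch of the variance and covariance terms, in the spirit of Bardes et al. (2021); the threshold, weighting and normalisation choices here are assumptions rather than the published recipe:

```python
import torch

def variance_covariance_terms(v, std_target=1.0, eps=1e-4):
    """VICReg-style regulariser on a batch of embeddings v of shape (batch, dim).
    Variance term: hinge keeping the std of every component above std_target.
    Covariance term: push off-diagonal covariance entries towards zero."""
    std = torch.sqrt(v.var(dim=0) + eps)
    var_term = torch.relu(std_target - std).mean()
    v_centered = v - v.mean(dim=0)
    cov = (v_centered.T @ v_centered) / (v.shape[0] - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_term = (off_diag ** 2).sum() / v.shape[1]
    return var_term, cov_term
```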

[!quote] While contrastive methods ensure that representations of different inputs in a batch are different, VICReg ensures that different components of representations over a batch are different. VICReg is contrastive over components, while traditional contrastive methods are contrastive over vectors, which requires a large number of contrastive samples.

With the above training criteria, JEPA strikes a balance between the completeness and predictability of the representations. The encoders and predictor together form an inductive bias on what information is predictable; we want to represent as much information as possible, but we want our representation to contain only information that is relevant for predicting $s_y$.

Hierarchical JEPA (H-JEPA)

Most importantly, non-contrastive training of JEPA is designed to learn abstractions of the world around us. Not only does it learn to abstract away unnecessary details, but it does so in such a way that the resulting encoding allows prediction of the next one. Where it can’t, the missing information can be pushed into the predictor’s latent variable, though even this is kept in check by the fourth criterion above.

LeCun has the idea of extending JEPA by stacking it hierarchically. In the diagram below, JEPA-1 extracts low-level representations and performs short-term predictions. JEPA-2 takes the representations extracted by JEPA-1 as inputs and extracts higher-level representations, perhaps through convolutional modules, which are useful over a longer time horizon. In doing so, JEPA-2 abstracts away even more details, giving coarser but more general predictions of the state of the world.

![[Screen Shot 2022-09-16 at 7.32.20 am.png | 600]]

This idea is best illustrated with LeCun’s example of driving a car. In this case, the JEPA-1 of our brain monitors the position of the car, the steering wheel and the pedals, and can produce a highly accurate trajectory prediction for the car over the next few seconds. Conversely, when planning an hour-long trip, there are low-level details that are impossible to predict: traffic density, the exact route we’ll take (due to closed roads or tolls) and traffic-light signals. However, by abstracting away much of this detail, we can still give an accurate prediction of when we might arrive at our destination, and the approximate trajectory we’ll take to get there. This corresponds to the JEPA-2 module of our brain. Of course, there may be many different modules, all stacked in a hierarchical manner such that each one performs more abstraction than the last.

Hierarchical planning

Given Kahneman’s System-1 and System-2 distinction, we want to know whether H-JEPA could be used to plan hierarchically. Most approaches to this require a pre-existing vocabulary of actions. In contrast, LeCun proposes that the intermediate representations of action plans should also be learned.

Let’s construct our H-JEPA System-2 module as follows. Suppose we read in the state of the world $x$. We can then encode $x$ to give $s[0]$ and create a prediction $s[1]$ about the state of the world, given a chosen action $a[0]$. We can repeat this for as long a sequence of actions as we’d like. However, what if we introduced an encoder which took the original encoding $s[0]$ and produced an even higher-level encoding? Then, after predicting the state of the world given some higher-level actions $a2[i]$, we can create subgoals $C(s[i])$ (lower-level cost modules) that are fed to the lower layer. The lower layer then uses these costs to generate an action sequence that minimises the subgoal costs. This could be extended to an $n$-level H-JEPA module, decomposing complex tasks into smaller and smaller subgoals.
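
To make the lower level’s role concrete, here is a heavily hedged sketch: the level above hands down a subgoal cost $C(\cdot)$, and the lower level searches for an action sequence that minimises the accumulated cost, here by gradient descent over continuous actions; `pred`, `s0`, `subgoal_cost` and all dimensions are placeholders, not modules from the paper:

```python
import torch

def plan_low_level(pred, s0, subgoal_cost, horizon, dim_a, steps=50, lr=0.1):
    """Given a predictor pred(state, action) -> next state and a subgoal cost C(.)
    supplied by the level above, find a continuous action sequence minimising the
    accumulated cost by gradient descent."""
    actions = torch.zeros(horizon, dim_a, requires_grad=True)
    opt = torch.optim.SGD([actions], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        s, cost = s0, 0.0
        for t in range(horizon):
            s = pred(s, actions[t])
            cost = cost + subgoal_cost(s)  # C(s[t]) defined by the higher level
        cost.backward()
        opt.step()
    return actions.detach()
```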

![[Screen Shot 2022-09-16 at 7.45.54 am.png | 500]]

Whilst this process is sequential and top-down, and hence greedy, we may instead optimise high- and low-level actions jointly by iterating over them. This echoes earlier work in control theory, where a high-level action is merely a condition to be satisfied by the lower level.

The best way to handle low-level prediction would likely be extracting simple local feature vectors and displacing those feature vectors to the next time step, according to predicted motion. The latent variables in this case would encode a map of displacements. Conversely, higher-level and longer-term prediction relies on objects and their interactions, meaning the evolution may be best modelled by a transformer. This is because such an architecture is equivariant to permutation and is able to capture interactions between discrete objects (Vaswani et al., 2017; Carion et al., 2020; Battaglia et al., 2016).

Handling uncertainty

As discussed above, the real world is inherently unpredictable, with multiple predictions $y$ being compatible with a current world-state $x$. To deal with this, we push the stochasticity of $y$ into a latent variable. But how exactly do we perform hierarchical planning in the presence of uncertainty?

Uncertainty can be handled by giving the predictors latent variables. These contain information about the prediction that cannot be derived from prior observation. They must also be regularised, both to prevent an energy collapse and to minimise the system’s reliance on the information carried by the latent variable. At planning time, latent variables are sampled from the distributions obtained by applying a Gibbs distribution to the regularisers, so the regulariser determines which latents are likely. Each sample leads to a different prediction. To produce consistent latent sequences, the parameters of the regulariser can be functions of previous states and retrieved memories.
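
A small sketch of this sampling step, assuming a discrete set of candidate latents and a regulariser $R(z)$ that returns a scalar cost (the temperature and the candidate set are assumptions):

```python
import torch

def sample_latent(z_candidates, regulariser, temperature=1.0):
    """Sample a latent from the Gibbs distribution induced by its regulariser:
    p(z) is proportional to exp(-R(z) / T), over a small discrete candidate set."""
    costs = torch.stack([regulariser(z) for z in z_candidates])  # R(z) for each candidate
    probs = torch.softmax(-costs / temperature, dim=0)
    idx = torch.multinomial(probs, num_samples=1).item()
    return z_candidates[idx]
```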

![[Screen Shot 2022-09-16 at 7.54.01 am.png | 500]]

However, as any computer scientist will have noticed, having a variety of possible values for the latent variable at each step is problematic. After $t$ time steps, if each latent variable has $k$ possible discrete values, the number of trajectories grows as $k^t$. Since the values are discrete, it is unlikely we can use gradient-based optimisation methods, and we would instead have to rely on directed search and pruning strategies (e.g. Monte-Carlo Tree Search).

Given a sample of all the latents, the predictor then computes the optimal action sequence at each level. We may also draw the latents repeatedly to cover all plausible outcomes. In this way, we can minimise not just the expected cost but also the uncertainty around it.

Keeping track of the state of the world

Traditionally, deep learning models communicate states through vectors or tensors, which is inefficient when we are often only changing minor aspects of the world at each time step. For instance, any given action by the agent will only change a small part of the world. Moving a bottle from the kitchen to the dining room won’t affect the geopolitical stability in Kenya. Only the states of the bottle, the kitchen and the dining room are altered.

Hence, the state of the world should probably be captured in some form of writable memory that doesn’t require rewriting the entire state each time part of it changes. A conventional key-value associative memory can be used for this purpose, as in memory-augmented networks (Bordes et al., 2015; Sukhbaatar et al., 2015; Miller et al., 2016) and entity networks (Henaff et al., 2017).

The output of the world model at a given time step is a set of query-value pairs $(q[i], v[i])$ which are used to modify existing entries in the world-state memory, or to add new entries. Given a query $q$, the world-state memory returns \begin{align*} \text{Mem}(q) &= \sum_j c_jv_j \\ \tilde{c}_j &= \text{Match}(k_j, q) \\ c &= \text{Normalise}(\tilde{c})\end{align*} where the $k_j$ are keys, the $v_j$ are stored values, the function $\text{Match}(k, q)$ measures a divergence or dissimilarity between a key and a query, the vector $c$ contains the scalar coefficients $c_j$, and the function $\text{Normalise}(\tilde{c})$ performs some sort of competitive normalisation or thresholding, such as $c_j = \exp(\tilde{c}_j)/[\gamma + \sum_k \exp(\tilde{c}_k)]$, where $\gamma$ is a positive constant.

Writing a value $r$ using query (or address) $q$ into the memory can be done by updating existing entries: \begin{align*} \tilde{c}_j &= \text{Match}(k_j, q) \\ c &= \text{Normalise}(\tilde{c}) \\ v_j &= \text{Update}(r, v_j, c_j)\end{align*} Function $\text{Update}(r,v,c)$ may be as simple as $cr+(1-c)v$.

If the query is distant from all keys, the memory may allocate a new entry whose key is $q$ and corresponding value is $r$. The $\gamma$ constant in the Normalise function may serve as a threshold for acceptable key-query divergence.
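
Putting the read, update, and allocation rules together, a minimal sketch; Match is implemented here as a negative squared distance so that closer keys receive larger coefficients under the softmax-with-$\gamma$ normalisation, and the 0.5 allocation threshold is an arbitrary assumption:

```python
import numpy as np

class WorldStateMemory:
    """Sketch of the key-value world-state memory described above."""
    def __init__(self, gamma=1.0):
        self.keys, self.values, self.gamma = [], [], gamma

    def _coeffs(self, q):
        c_tilde = np.array([-np.sum((k - q) ** 2) for k in self.keys])  # Match(k_j, q)
        e = np.exp(c_tilde)
        return e / (self.gamma + e.sum())                               # Normalise(c_tilde)

    def read(self, q):
        c = self._coeffs(q)
        return sum(cj * vj for cj, vj in zip(c, self.values))           # Mem(q) = sum_j c_j v_j

    def write(self, q, r):
        c = self._coeffs(q) if self.keys else np.array([])
        if c.sum() < 0.5:                  # query far from every key: allocate a new entry
            self.keys.append(q); self.values.append(r)
        else:                              # otherwise update existing entries towards r
            self.values = [cj * r + (1 - cj) * vj for cj, vj in zip(c, self.values)]
```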

One can view each entry as representing the state of an entity in the world. In the bottle example above, the world model may contain keys $k_\text{bottle}, k_\text{kitchen}, k_\text{dining room}$. The initial value of $v_\text{bottle}$ encodes its location as being “kitchen”, and so on. After the move, the relevant locations and contents are updated. All of these operations can be done in a differentiable manner, and hence allow gradients to be back-propagated through them.

Comparing JEPA and closed loop feedback (LeCun vs Ma)

Large language models (LLMs)

The main reason LeCun believes scaling will not be enough to lead to emergent common sense is that LLMs are, by construction, generative. Inputs are converted into a sequence of discrete tokens, which works well for text but flounders in high-dimensional, continuous spaces.

And yet LeCun remains ever the empiricist. “I speculate that common sense may emerge from learning world models that capture the self-consistency and mutual dependencies of observations in the world, allowing an agent to fill in missing information and detect violations of its world model.”

What needs to be done?

According to LeCun, there are several non-trivial problems which constitute the “known unknowns” of his JEPA-based architecture. Probably