# What can neural networks reason about?

Neural network structures that align better with the underlying reasoning algorithm generalize better in reasoning tasks.

- Algorithm point of view
- Algorithm alignment
- How GNNs relate to DP
- Maximum value difference
- Conclusion

Recently, a lot of research has focused on building neural networks that can learn to reason. Examples include simulating particle physics [1] and discovering mathematical equations from data [2].

Figure 1 shows an example of a reasoning task. In such tasks, the model has to learn fundamental properties of the environment from data. In this example, it first has to learn the meaning of the words *furthest*, *color*, and *pair*.

Many reasoning tasks can be formalized as graph problems, and message passing [3] has been shown to be a key component of modern graph neural networks (GNNs). But why can GNNs solve these problems when multilayer perceptrons (MLPs) cannot, even though both have the same expressive power? Put differently: why do MLPs fail to generalize on these reasoning tasks?

Formally, we will try to answer the following question:

When does one network structure generalize better than another, even if both have the same expressive power?

## Algorithm point of view

Let us begin with the observation that the reasoning process resembles an algorithm. To solve the reasoning problem of Figure 1, we can come up with the following solution:

- Find the location of all pairs of objects.
- Determine the pair that is furthest (this is just a heuristic that you decide).
- Return the color of the object.

The above three steps resemble an *algorithm*, where we are defining a step-by-step procedure to solve a problem.
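The three steps above can be sketched in a few lines of Python. This is only an illustration: the toy scene, the `location` and `color` field names, and the choice to return the colors of both objects in the pair are assumptions, since the exact question in Figure 1 is not reproduced here.

```python
from itertools import combinations
from math import dist

# Hypothetical toy scene: each object has a 2-D location and a color.
objects = [
    {"location": (0.0, 0.0), "color": "red"},
    {"location": (1.0, 1.0), "color": "green"},
    {"location": (5.0, 4.0), "color": "blue"},
    {"location": (2.0, 0.5), "color": "yellow"},
]

def furthest_pair_colors(objects):
    # Step 1: enumerate all pairs of objects.
    pairs = combinations(objects, 2)
    # Step 2: pick the pair with the largest Euclidean distance.
    a, b = max(pairs, key=lambda p: dist(p[0]["location"], p[1]["location"]))
    # Step 3: return the colors of the objects in that pair.
    return a["color"], b["color"]

result = furthest_pair_colors(objects)
print(result)
```

Each step is simple on its own; the structure comes from looping over all pairs and then taking a maximum.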

From the paper,

"We will build on this observation and study how well a reasoning algorithm *aligns* with the computation graph of the network. Intuitively, if they align well, the network only needs to learn simple algorithm steps to simulate the reasoning process, which leads to better efficiency."

The authors of the paper formalize the above as **algorithm alignment**.

## Algorithm alignment

Every neural network architecture has an underlying computation structure. MLPs resemble a for-loop kind of computation structure, as they are applied to flat vector inputs. GNNs, in contrast, aggregate information from neighboring nodes, which implies a dynamic programming (DP) type of computation (as shown in the paper).

## How GNNs relate to DP

The Bellman-Ford algorithm [4] solves the shortest path problem. Its main step is

```
for u in Nodes:
    d[k][u] = min_v (d[k-1][v] + cost(v, u))
```

where k is the iteration number (1, 2, …, num_nodes - 1).
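For reference, here is a minimal runnable sketch of the full algorithm built around that update. The function name, the edge dictionary format, and the toy graph are assumptions made for illustration. The sketch also keeps `d[k-1][u]` inside the minimum, which is equivalent to giving every node a zero-cost self-loop.

```python
import math

def bellman_ford(num_nodes, edges, source):
    """Shortest distances from `source`; `edges` maps (v, u) -> cost(v, u)."""
    d = [math.inf] * num_nodes
    d[source] = 0.0
    for k in range(1, num_nodes):  # k = 1, ..., num_nodes - 1
        prev = d[:]                # distances from iteration k - 1
        for (v, u), cost in edges.items():
            # Relaxation: keep the better of the current value and the path via v.
            d[u] = min(d[u], prev[v] + cost)
    return d

# Hypothetical directed graph with four nodes.
edges = {(0, 1): 4.0, (0, 2): 1.0, (2, 1): 2.0, (1, 3): 1.0, (2, 3): 5.0}
distances = bellman_ford(4, edges, source=0)
print(distances)  # [0.0, 3.0, 1.0, 4.0]
```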

Now let's see how to do the same using a GNN. The message passing update is

$h_u^k = \mathrm{UPDATE}(h_u^{k-1}, \mathrm{AGGREGATE}(\mathrm{FUNC}(h_u^{k-1}, h_v^{k-1})))$

Now let

- UPDATE = identity function
- FUNC = MLP (the most common choice)
- AGGREGATE = minimum

Using this information, the message passing equation becomes $h_u^k = \min_v \mathrm{MLP}(h_u^{k-1}, h_v^{k-1})$

The above equation now closely resembles the Bellman-Ford update: the GNN only has to learn the simple step inside the minimum. An MLP applied to the whole graph at once would instead have to learn the structure of the entire for-loop, which is expensive.
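To make the correspondence concrete, here is a minimal sketch of message passing with AGGREGATE = minimum and UPDATE = identity. The learned MLP is replaced by a hand-written function returning `h_v + cost(v, u)`. That stand-in, the extra edge-cost input, and the toy cost matrix are all assumptions for illustration and are not taken from the paper's code.

```python
import math

INF = math.inf

def func(h_u, h_v, cost_vu):
    # Hand-written stand-in for the learned MLP (assumption):
    # the Bellman-Ford relaxation step. h_u is unused here.
    return h_v + cost_vu

def message_passing_round(h, cost):
    """One round: h_u^k = min_v FUNC(h_u^{k-1}, h_v^{k-1})."""
    n = len(h)
    return [min(func(h[u], h[v], cost[v][u]) for v in range(n)) for u in range(n)]

# cost[v][u] is the edge cost from v to u; INF means no edge,
# and the zero diagonal lets a node keep its previous value.
cost = [
    [0.0, 4.0, 1.0, INF],
    [INF, 0.0, INF, 1.0],
    [INF, 2.0, 0.0, 5.0],
    [INF, INF, INF, 0.0],
]

h = [0.0, INF, INF, INF]      # node 0 is the source
for _ in range(len(h) - 1):   # num_nodes - 1 rounds
    h = message_passing_round(h, cost)
print(h)  # [0.0, 3.0, 1.0, 4.0]
```

The loop over nodes and the minimum are fixed by the architecture, so only `func` would need to be learned.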

This is the main point of the paper. If we can find a suitable underlying reasoning algorithm for the reasoning task, then we can use neural network structures that better *align* with the structure of that algorithm. This makes the task easier to learn and improves generalization.

## Maximum value difference

In the paper, the authors run experiments on four reasoning tasks:

- Maximum value difference
- Furthest pair
- Monster trainer
- Subset sum

I tried to reproduce the results for the maximum value difference task. The task is simple: given a vector, find the difference between its maximum and minimum values. But why is it important?

In a lot of reasoning tasks, we are required to answer questions related to summary statistics (like count, min, max). For example, “How many objects are either small cylinders or red things?”. With a GNN, we can simulate the reasoning algorithm by using an MLP to extract features from each node and then using aggregation to come up with the answer. In this case, the MLP only has to learn to extract local features. If, on the other hand, we used only an MLP to solve this problem, it would have to learn a complex for-loop and would therefore need more data to converge.
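The counting question above can be sketched as a per-object feature followed by a sum. The objects and attribute names are hypothetical, and `local_feature` is a hand-written stand-in for what the per-node MLP would have to learn.

```python
# Hypothetical scene; attribute names are illustrative only.
objects = [
    {"size": "small", "shape": "cylinder", "color": "blue"},
    {"size": "large", "shape": "cube", "color": "red"},
    {"size": "small", "shape": "sphere", "color": "green"},
    {"size": "small", "shape": "cylinder", "color": "red"},
    {"size": "large", "shape": "cylinder", "color": "yellow"},
]

def local_feature(obj):
    # Stand-in for the per-object MLP: 1 if the object matches, else 0.
    is_small_cylinder = obj["size"] == "small" and obj["shape"] == "cylinder"
    return int(is_small_cylinder or obj["color"] == "red")

# Aggregation: summing the local features gives the count.
count = sum(local_feature(obj) for obj in objects)
print(count)  # 3
```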

The maximum value difference task is stated as:

- A training sample consists of 25 treasures ($X$).
- For each treasure ($X_i$), we have $X_i = [h_1, h_2, h_3]$, where
  - $h_1$ is an 8-dim location vector sampled uniformly from [0,20]
  - $h_2$ is the value, sampled uniformly from [0,100]
  - $h_3$ is the color, sampled uniformly from [1,6]

For the maximum value difference task, we have to find the difference between the maximum value and the minimum value for each training sample.
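As a sketch of the task statement above, the following builds one sample and its label using only the standard library. The function name and the per-treasure layout (location, then value, then color) are assumptions, as is treating values and colors as integers. Note that `random.randint` includes both endpoints.

```python
import random

def make_sample(rng, num_treasures=25):
    """One sample following the task statement: 25 treasures and a label."""
    treasures = []
    for _ in range(num_treasures):
        location = [rng.uniform(0, 20) for _ in range(8)]  # h_1
        value = rng.randint(0, 100)                        # h_2, endpoints inclusive
        color = rng.randint(1, 6)                          # h_3, endpoints inclusive
        treasures.append(location + [value, color])
    values = [t[8] for t in treasures]
    label = max(values) - min(values)  # maximum value difference
    return treasures, label

rng = random.Random(0)
treasures, label = make_sample(rng)
print(len(treasures), len(treasures[0]), label)
```

With integer values in [0,100], the label is always an integer between 0 and 100.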

### MLP

To input a training sample, we simply concatenate the vector representations of all 25 treasures and feed the result into an MLP. We treat this as a classification problem where the task is to predict a value from 0 to 100, i.e. one of 101 classes.

The code to generate the data is in min_max_mlp_data.py. A quick summary of the data generation process is shown below.

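```
import numpy as np

for i in range(num_samples):
    location = np.random.uniform(0, 20, size=(25, 8))  # 8-dim location per treasure
    value = np.random.randint(100, size=(25, 1))       # upper bound is exclusive: 0..99
    color = np.random.randint(1, 6, size=(25, 1))      # upper bound is exclusive: 1..5
    min_val = np.min(value)
    max_val = np.max(value)
    # Flatten the 25 x 10 feature matrix into a single 250-dim input vector
    t = np.concatenate((location, value, color), axis=1).flatten()
```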

The code to create the MLP is in model_mlp.py.

```
Linear(in_features=250, out_features=256, bias=True),
ReLU(inplace=True),
Linear(in_features=256, out_features=256, bias=True),
ReLU(inplace=True),
Linear(in_features=256, out_features=256, bias=True),
ReLU(inplace=True),
Linear(in_features=256, out_features=256, bias=True),
ReLU(inplace=True),
Linear(in_features=256, out_features=256, bias=True),
ReLU(inplace=True),
Linear(in_features=256, out_features=101, bias=True)
```
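As a quick sanity check on the listing above, the input width and the number of trainable parameters can be worked out from the layer sizes alone. This is plain arithmetic on the printed shapes, not code from model_mlp.py.

```python
# Layer widths read off the listing: 250 -> five hidden layers of 256 -> 101.
layer_sizes = [250, 256, 256, 256, 256, 256, 101]

# Each Linear layer has n_in * n_out weights plus n_out biases.
num_params = sum(
    n_in * n_out + n_out
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
)

# 25 treasures x (8 location dims + 1 value + 1 color)
input_dim = 25 * (8 + 1 + 1)

print(input_dim, num_params)  # 250 353381
```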

The code to test the model is in mlp.ipynb. The MLP achieves around 8% accuracy on the validation data, which is the expected result.

### GNN

For the GNN, I construct a fully connected graph with 25 nodes, each node representing a treasure. I use pytorch_geometric to implement the GNN.

The code to generate the data is in min_max_graph_data.py, and the code to construct the GNN is in model_gnn.py.

The best GNN model reached 98.5% accuracy (a hyperparameter search might push this to 100%). This demonstrates that a GNN can easily learn summary statistics, which are a key component of reasoning problems.
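To see why this task aligns well with a GNN, here is a hand-written sketch of the computation the network has to approximate: a per-node feature extractor followed by max and min aggregation. It does not use pytorch_geometric and is not the code in model_gnn.py. The feature layout (value at index 8) and the sample treasures are assumptions.

```python
VALUE_INDEX = 8  # assumed layout: 8 location dims, then value, then color

def extract(treasure):
    # Stand-in for the learned per-node MLP: read off the treasure's value.
    return treasure[VALUE_INDEX]

def max_value_difference(treasures):
    features = [extract(t) for t in treasures]
    # Fully connected graph: every node is a neighbor of every other node,
    # so max and min aggregation see all node features.
    return max(features) - min(features)

# Three hypothetical treasures (locations set to zero for brevity).
sample = [
    [0.0] * 8 + [17, 3],
    [0.0] * 8 + [92, 1],
    [0.0] * 8 + [40, 6],
]
difference = max_value_difference(sample)
print(difference)  # 75
```

Only `extract` corresponds to something learned; the aggregation is supplied by the architecture.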

## Conclusion

The concept of algorithm alignment can be applied to any reasoning algorithm. If we can come up with a suitable algorithm to solve the reasoning problem, then we can design a network with a similar structure to learn it. If we have no prior knowledge about the structure of the reasoning algorithm, then a neural architecture search over algorithm structures will be needed.

- [1] Learning to Simulate Complex Physics with Graph Networks
- [2] Discovering Symbolic Models from Deep Learning with Inductive Biases
- [3] Relational inductive biases, deep learning, and graph networks
- [4] Bellman-Ford algorithm, Wikipedia
