paper, my implementation

Recently, a lot of research has focused on building neural networks that can learn to reason. Examples include simulating particle physics [1] and discovering mathematical equations from data [2].

Figure 1 shows an example of a reasoning task. In such tasks, the model has to learn the fundamental properties of the environment from data. In this example, the model first has to learn the meaning of the words “furthest”, “color”, and “pair”.

Many reasoning tasks can be formalized as graph problems, and message passing [3] has been shown to be a key component of modern graph neural networks (GNNs). But why can GNNs solve these problems when MLPs cannot, even though both have the same expressive power? Put differently: why do MLPs fail to generalize on these reasoning tasks?

Formally, we will try to answer the following question:

When does one network structure generalize better than another, even if both have the same expressive power?

Algorithm point of view

Let us begin with the observation that the reasoning process resembles an algorithm. To solve the reasoning problem in Figure 1, we can come up with the following solution:

  1. Find the location of all pairs of objects.
  2. Determine the pair that is furthest (this is just a heuristic that you decide).
  3. Return the color of the object.

The three steps above form an algorithm: a step-by-step procedure for solving a problem.
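As a rough illustration, the three steps can be written as a few lines of Python. This is only a sketch under assumptions: the toy scene, the 2-D locations, the `furthest_pair_colors` name, and the choice to return the colors of both objects in the pair are all hypothetical, since the exact question in Figure 1 is not reproduced here.

```python
from itertools import combinations
from math import dist

# Hypothetical toy scene: each object has a 2-D location and a color.
objects = [
    {"location": (0.0, 0.0), "color": "red"},
    {"location": (1.0, 1.0), "color": "green"},
    {"location": (5.0, 4.0), "color": "blue"},
    {"location": (2.0, 0.5), "color": "yellow"},
]

def furthest_pair_colors(objects):
    # Step 1: enumerate all pairs of objects (and hence their locations).
    pairs = combinations(objects, 2)
    # Step 2: pick the pair with the largest Euclidean distance.
    a, b = max(pairs, key=lambda p: dist(p[0]["location"], p[1]["location"]))
    # Step 3: return the colors of the objects in that pair (an assumption).
    return a["color"], b["color"]

result = furthest_pair_colors(objects)
print(result)
```

Each step is simple on its own; the difficulty for a network lies in composing them, which is where the structure of the architecture comes in.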

From the paper,

We will build on this observation and study how well a reasoning algorithm aligns with the computation graph of the network. Intuitively, if they align well, the network only needs to learn simple algorithm steps to simulate the reasoning process, which leads to better efficiency.

The authors of the paper formalize the above as algorithm alignment.

Algorithm alignment

Every neural network architecture that we build has an underlying computation structure. MLPs resemble a for-loop kind of computation structure, as they are applied to vector inputs. GNNs, in contrast, aggregate information from neighbors, which implies a dynamic programming (DP) type of computation (shown in the paper).

How GNNs relate to DP

The Bellman-Ford algorithm [4] solves the shortest path problem. The main step of the algorithm is

for u in Nodes:
    d[k][u] = min_v (d[k-1][v] + cost(v, u))

where k is the iteration number (1, 2, …, num_nodes-1).
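The update above can be turned into a small runnable function. This is a minimal sketch, not the post's implementation: the `bellman_ford` name, the `(v, u, cost)` edge format, and the example graph are assumptions made for illustration.

```python
import math

def bellman_ford(num_nodes, edges, source):
    """edges: list of (v, u, cost) tuples, each a directed edge v -> u."""
    # After iteration k, d[u] is the shortest distance from source to u
    # using at most k edges.
    d = [math.inf] * num_nodes
    d[source] = 0.0
    for k in range(1, num_nodes):
        # Start from a copy so every update reads d[k-1], not d[k].
        # Keeping the old value is the same as allowing v == u with cost 0.
        new_d = list(d)
        for v, u, cost in edges:
            new_d[u] = min(new_d[u], d[v] + cost)
        d = new_d
    return d

# Hypothetical 4-node graph.
edges = [(0, 1, 4.0), (0, 2, 1.0), (2, 1, 2.0), (1, 3, 1.0), (2, 3, 5.0)]
distances = bellman_ford(4, edges, source=0)
print(distances)
```

Note that the inner step is just "add a cost, take a minimum"; everything else is the loop over nodes and iterations.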

Now let’s see how to do the same using a GNN. The message passing update is

$h_u^k = \mathrm{UPDATE}\left(h_u^{k-1}, \mathrm{AGGREGATE}\left(\mathrm{FUNC}(h_u^{k-1}, h_v^{k-1})\right)\right)$

Now let

  • UPDATE = identity function (it returns the aggregated message unchanged)
  • FUNC = MLP (as in most cases)
  • AGGREGATE = minimum

Using this information, the message passing equation becomes $h_u^k = \min_v \mathrm{MLP}(h_u^{k-1}, h_v^{k-1})$

The above equation now closely resembles the Bellman-Ford update: the MLP only has to learn the simple relaxation step d[k-1][v] + cost(v,u), while the minimum over neighbors is built into the architecture. If we had used a single MLP to solve this problem, it would have to learn the structure of the entire for-loop, which is expensive.
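To make the correspondence concrete, here is a sketch of min-aggregation message passing in plain Python. In place of the learned MLP it uses a hand-written `relax` function, so it shows what the MLP would need to learn rather than how it is trained. The function names, the dense `cost` matrix, and passing the edge cost as a third argument to FUNC are assumptions made for this illustration.

```python
import math

def message_passing_step(h, cost, func):
    """One round: h_u <- min over v of func(h_u, h_v, cost[v][u])."""
    n = len(h)
    return [
        min(func(h[u], h[v], cost[v][u]) for v in range(n))
        for u in range(n)
    ]

def relax(h_u, h_v, c_vu):
    # Hand-written stand-in for the MLP: the Bellman-Ford relaxation.
    return h_v + c_vu

INF = math.inf
# cost[v][u] is the cost of edge v -> u (INF if absent).
# The zero diagonal lets a node keep its current value.
cost = [
    [0.0, 4.0, 1.0, INF],
    [INF, 0.0, INF, 1.0],
    [INF, 2.0, 0.0, 5.0],
    [INF, INF, INF, 0.0],
]

h = [0.0, INF, INF, INF]  # distance estimates from node 0
for _ in range(len(h) - 1):
    h = message_passing_step(h, cost, relax)
print(h)
```

The loop over nodes and the `min` come for free from the message passing structure; only `relax` is task-specific.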

This is the main point of the paper. If we can find a suitable underlying reasoning algorithm for the reasoning task, then we can use neural network structures that better align with the underlying algorithm structure. This will make the task easy to learn and improve generalization.

Maximum value difference

The paper runs experiments on four reasoning tasks:

  • Maximum value difference
  • Furthest pair
  • Monster trainer
  • Subset sum

I try to reproduce the results of the maximum value difference task. The task is simple: given a vector, find the difference between its maximum and minimum values. But why is it important?

In many reasoning tasks, we have to answer questions about summary statistics (such as count, min, max). For example, “How many objects are either small cylinders or red things?”. With a GNN, we can simulate the reasoning algorithm by using an MLP to extract features from each node and then aggregating them to produce the answer. In this case, the MLP only has to learn to extract local features. If, on the other hand, we used only an MLP to solve this problem, it would have to learn a complex for-loop and would therefore need more data to converge.
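The split between local feature extraction and aggregation can be sketched for the example question. This is an illustration only: the scene, the attribute names, and the `local_feature` and `answer` functions are hypothetical, and the per-object rule is written by hand where a trained model would use an MLP.

```python
# Hypothetical scene description.
objects = [
    {"shape": "cylinder", "size": "small", "color": "blue"},
    {"shape": "cube", "size": "large", "color": "red"},
    {"shape": "cylinder", "size": "large", "color": "green"},
    {"shape": "sphere", "size": "small", "color": "red"},
    {"shape": "cylinder", "size": "small", "color": "red"},
]

def local_feature(obj):
    # Per-object step (the part an MLP would learn):
    # 1 if the object is a small cylinder or red, else 0.
    is_small_cylinder = obj["size"] == "small" and obj["shape"] == "cylinder"
    is_red = obj["color"] == "red"
    return int(is_small_cylinder or is_red)

def answer(objects):
    # Aggregation step (built into the architecture): sum over all objects.
    return sum(local_feature(obj) for obj in objects)

count = answer(objects)
print(count)
```

The per-object rule never looks at more than one object, which is why it is easy to learn from few samples.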

The maximum value difference task is stated as:

  1. A training sample consists of 25 treasures ($X$).
  2. For each treasure ($X_i$), we have $X_i = [h_1,h_2,h_3]$ where
    • $h_1$ is an 8-dimensional location vector sampled uniformly from [0,20]
    • $h_2$ is a value sampled uniformly from [0,100]
    • $h_3$ is a color sampled uniformly from [1,6]

For the maximum value difference task, we have to find the difference between the maximum value and the minimum value in each training sample.


To input a training sample, we simply concatenate the vector representations of all 25 treasures and feed the result into an MLP. We treat this as a classification problem: the target is an integer from 0 to 100, i.e. 101 classes.

The code to generate the data is in A quick summary of data generation process is shown below

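import numpy as np

num_samples = 1000  # illustrative value; not specified in this summary
for i in range(num_samples):
    # 25 treasures per sample, each with an 8-dim location in [0, 20)
    location = np.random.uniform(0, 20, size=(25, 8))
    # randint excludes the upper bound: use 101 and 7 to cover [0, 100] and [1, 6]
    value = np.random.randint(0, 101, size=(25, 1))
    color = np.random.randint(1, 7, size=(25, 1))

    # label: difference between the largest and smallest treasure value
    min_val = np.min(value)
    max_val = np.max(value)
    label = max_val - min_val

    # flatten the 25 x 10 feature matrix into one 250-dim MLP input
    t = np.concatenate((location, value, color), axis=1).flatten()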

The code to create MLP is in

Linear(in_features=250, out_features=256, bias=True),
Linear(in_features=256, out_features=256, bias=True),
Linear(in_features=256, out_features=256, bias=True),
Linear(in_features=256, out_features=256, bias=True),
Linear(in_features=256, out_features=256, bias=True),
Linear(in_features=256, out_features=101, bias=True)

The code to test the model is in mlp.ipynb. The MLP achieves around 8% accuracy on the validation data. This is the expected result: the MLP struggles to learn the for-loop over all 25 treasures.


For the GNN, I construct a fully connected graph with 25 nodes (each node represents a treasure). I use pytorch_geometric to implement the GNN.
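pytorch_geometric describes graph connectivity with an `edge_index` of shape `[2, num_edges]`, where the first row holds source nodes and the second row holds target nodes. The sketch below builds that structure for a fully connected graph as plain Python lists. The `fully_connected_edge_index` name and the choice to leave out self-loops are assumptions, not necessarily what the original notebook does.

```python
def fully_connected_edge_index(num_nodes):
    """Return [sources, targets] for every directed edge u -> v with u != v."""
    sources, targets = [], []
    for u in range(num_nodes):
        for v in range(num_nodes):
            if u != v:  # assumption: no self-loops
                sources.append(u)
                targets.append(v)
    return [sources, targets]

edge_index = fully_connected_edge_index(25)
num_edges = len(edge_index[0])
print(num_edges)  # 25 * 24 directed edges
```

With PyTorch installed, `torch.tensor(edge_index, dtype=torch.long)` would turn this into the tensor form that pytorch_geometric expects.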

The code to generate the data is in and the code to construct GNN is in

The best GNN model reached 98.5% accuracy (a hyperparameter search might push this to 100%). This demonstrates that a GNN can easily learn summary statistics, which are a key component of reasoning problems.


The concept of algorithm alignment can be applied to any reasoning algorithm. If we can come up with a suitable algorithm to solve the reasoning problem, then we can design a network with a similar structure to learn it. If we have no prior knowledge about the structure of the reasoning algorithm, then a neural architecture search over algorithm structures will be needed.

If you want to read more about Graph Deep Learning, see my other posts, or follow me on twitter.