Machine Learning on Graph Data




Many real-world datasets can naturally be framed as graphs. For example, on online platforms such as social networks, users can be represented as nodes, and follows or likes can be represented as edges.

However, when building models on data from these domains, people often simplify the problem by ignoring the underlying graph structure. In doing so, machine learning practitioners discard useful information that would help situate an entity (e.g. a user) within the broader network it is a part of.

In this post, I will cover various methods, ranging from very simple to complex, for incorporating graph information into machine learning models. I will mainly focus on tasks at the node level, such as node classification, since these are more common in industry than tasks at the level of the entire graph.

What are graphs?

Graphs are data structures that encode relationships between pairs of entities. Entities in the graph are referred to as nodes, and relationships are referred to as edges. Edges may be directed or undirected, depending on whether a relationship applies in one direction or in both. In a directed graph, the existence of an edge A → B does not imply the existence of an edge B → A, whereas in an undirected graph, it does.

For example, on a social network, one graph may connect users to other users based on who they follow. In this case, the nodes would be users, and the edges would encode who follows whom. The edges in this graph would be directed, since a user does not have to follow everyone who follows them.

Nodes in this graph also have features, such as the age of the account, or which country the user is from. These attributes distinguish nodes from each other, and provide additional context.
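As a minimal sketch of the follow graph described above, a directed graph can be stored as a pair of adjacency maps, with node features kept in a separate dictionary. The user names, feature names, and the build_directed_graph helper are all hypothetical and exist only for illustration.

```python
from collections import defaultdict

# Hypothetical follow relationships, written as (follower, followed).
follows = [
    ("alice", "bob"),
    ("bob", "alice"),
    ("carol", "alice"),
    ("dave", "alice"),
]

# Hypothetical node features: account age in days and country.
node_features = {
    "alice": {"account_age_days": 2100, "country": "US"},
    "bob": {"account_age_days": 340, "country": "CA"},
    "carol": {"account_age_days": 12, "country": "US"},
    "dave": {"account_age_days": 875, "country": "DE"},
}


def build_directed_graph(edges):
    """Return (out_neighbors, in_neighbors) adjacency sets for a directed edge list."""
    out_neighbors = defaultdict(set)
    in_neighbors = defaultdict(set)
    for source, target in edges:
        out_neighbors[source].add(target)
        in_neighbors[target].add(source)
    return out_neighbors, in_neighbors


out_neighbors, in_neighbors = build_directed_graph(follows)

# The graph is directed: carol follows alice, but alice does not follow carol.
print("alice" in out_neighbors["carol"])  # True
print("carol" in out_neighbors["alice"])  # False
```

Keeping both directions of the adjacency map makes it cheap to ask either "who does this user follow?" or "who follows this user?".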


In this social network graph, we would observe many of the phenomena we know about social networks, such as:

  • Some users, such as celebrities, would have many more edges coming in than going out, since celebrities have millions of fans but rarely follow everyone back.
  • Other users, such as bots, would show the opposite pattern, with many more edges going out than coming in, since bots tend to follow random people but real users rarely follow a bot.
  • Users will typically form clusters of densely connected groups consisting of users that tend to know each other. Groups of friends or people within the same interest group will tend to follow many of the same people.
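The first two patterns can be sketched by counting incoming and outgoing edges per user. The edge list below is made up for illustration, and the labels "celebrity" and "bot" are simply hypothetical node names, not the output of any detection method.

```python
from collections import Counter

# Hypothetical (follower, followed) pairs.
follows = [
    ("fan1", "celebrity"), ("fan2", "celebrity"), ("fan3", "celebrity"),
    ("bot", "fan1"), ("bot", "fan2"), ("bot", "fan3"), ("bot", "celebrity"),
    ("fan1", "fan2"), ("fan2", "fan1"),
]

# In-degree counts followers; out-degree counts accounts followed.
in_degree = Counter(target for _, target in follows)
out_degree = Counter(source for source, _ in follows)

for user in sorted(set(in_degree) | set(out_degree)):
    print(f"{user}: in={in_degree[user]} out={out_degree[user]}")
```

In this toy graph the celebrity node has four incoming edges and none outgoing, while the bot node has four outgoing edges and none incoming.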

Machine learning methods

Now we will begin covering how to leverage graph information in your models. These methods range from simple statistics to fully trainable graph neural networks, and each has its own advantages and disadvantages.

Graph statistics

The simplest way to leverage graph information is to calculate basic counts and ratios and use them as numerical features.

Some graph statistics include:

  • Degree, which counts how many edges a given node has.
  • Centrality, which measures, abstractly, how important a given node is to the connectivity of the overall graph. There are various types of centrality metrics, such as betweenness centrality and closeness centrality, but at a high level, they are higher for nodes that lie on paths that efficiently connect many nodes to each other.
  • Clustering coefficient, which effectively measures the density of a node’s local portion of the graph. Nodes whose neighbors are all connected to each other will have a higher clustering coefficient.
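The three statistics above can be sketched in plain Python on a small undirected graph. This is an illustrative sketch under stated assumptions: the graph is made up, it is assumed to be connected and unweighted, and closeness centrality is used as the one example of a centrality metric. In practice a graph library would usually compute these for you.

```python
from collections import deque
from itertools import combinations

# Hypothetical undirected graph stored as adjacency sets.
graph = {
    "a": {"b", "c", "d"},
    "b": {"a", "c"},
    "c": {"a", "b"},
    "d": {"a", "e"},
    "e": {"d"},
}


def degree(graph, node):
    """Number of edges attached to the node."""
    return len(graph[node])


def clustering_coefficient(graph, node):
    """Fraction of pairs of the node's neighbors that are themselves connected."""
    neighbors = graph[node]
    k = len(neighbors)
    if k < 2:
        return 0.0  # convention: undefined cases are reported as 0
    links = sum(1 for u, v in combinations(neighbors, 2) if v in graph[u])
    return links / (k * (k - 1) / 2)


def closeness_centrality(graph, node):
    """(reachable nodes) / (sum of shortest-path distances), found with BFS."""
    distances = {node: 0}
    queue = deque([node])
    while queue:
        current = queue.popleft()
        for neighbor in graph[current]:
            if neighbor not in distances:
                distances[neighbor] = distances[current] + 1
                queue.append(neighbor)
    total = sum(distances.values())
    return (len(distances) - 1) / total if total > 0 else 0.0


for node in sorted(graph):
    print(node, degree(graph, node),
          round(clustering_coefficient(graph, node), 3),
          round(closeness_centrality(graph, node), 3))
```

Each node ends up with three numbers that can be appended to its feature vector for a downstream model.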

These graph statistics are a great place to start, and they have easy-to-digest, standardized meanings.

However, they are relatively limited in what information they can convey:

  • They are unable to leverage node features at all. All nodes with the same values for these summary statistics are indistinguishable from each other.
  • There is no learnable component in the production of these features. We cannot fit a custom objective or train them jointly with a downstream task.

Node embeddings

The next way to leverage graph information is to learn node embeddings for each node in the graph, and use these embeddings as features in a downstream model. Node embeddings are learnable vectors of numbers that we can map to each node in the graph, allowing us to learn a unique representation for each node.

The most common way to learn these embeddings (from the DeepWalk and Node2vec papers) is to enforce that nodes close to each other have similar representations. Both methods use random walks, which begin at a given node and randomly traverse edges, to produce pairs of nodes that are near each other. The embeddings are then trained with a skip-gram-style objective that increases the similarity (typically the dot product) between the embeddings of nodes that co-occur in random walks. This training objective leverages the homophily assumption, which states that nodes that are connected to each other tend to be similar to each other.
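The pair-generation step can be sketched as follows. This is a simplified illustration, not the reference implementation of either paper: it uses uniform random walks (Node2vec biases its walks with extra parameters), the graph is hypothetical, and walk_length and window_size are arbitrary choices. The resulting pairs would then be fed to an embedding trainer, which is not shown.

```python
import random

# Hypothetical undirected graph stored as adjacency lists.
graph = {
    "a": ["b", "c", "d"],
    "b": ["a", "c"],
    "c": ["a", "b"],
    "d": ["a", "e"],
    "e": ["d"],
}


def random_walk(graph, start, walk_length, rng):
    """Start at a node and repeatedly step to a uniformly chosen neighbor."""
    walk = [start]
    while len(walk) < walk_length:
        neighbors = graph[walk[-1]]
        if not neighbors:
            break  # dead end: stop the walk early
        walk.append(rng.choice(neighbors))
    return walk


def cooccurrence_pairs(walk, window_size):
    """Pair every node in the walk with the nodes within window_size steps of it."""
    pairs = []
    for i, center in enumerate(walk):
        lo = max(0, i - window_size)
        hi = min(len(walk), i + window_size + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, walk[j]))
    return pairs


rng = random.Random(0)  # seeded so repeated runs produce the same walks
walks = [
    random_walk(graph, node, walk_length=5, rng=rng)
    for node in graph
    for _ in range(2)  # two walks per starting node
]
pairs = [pair for walk in walks for pair in cooccurrence_pairs(walk, window_size=2)]

print(walks[0])
print(len(pairs), "training pairs")
```

Nodes that sit in the same dense region of the graph appear together in many walks, so they contribute many pairs and are pushed toward similar embeddings.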

These node embedding methods allow us to learn task-independent representations for each node, and potentially have more representational power than the previous section’s methods, since they can learn an independent representation for each node.

However, while a step in the right direction, these methods still have downsides:

  • They do not use node features at all. They assume that close-by nodes are similar without actually using the node features to confirm this assumption.
  • They rely on a fixed mapping from node to embedding (i.e. this is a transductive method). This means that for dynamic graphs, where new nodes and edges may be added, the algorithm must be re-run from scratch, and all node embeddings need to be recalculated. In real-world problems, this is quite a big issue, as most online platforms have new users signing up every day, and new edges being created constantly.

Graph convolutional networks

The last method I will cover is the graph convolutional network (GCN). This is a more advanced way to perform graph representation learning that overcomes many of the shortcomings of the previous node embedding methods.

A graph convolutional network learns representations of nodes by learning a function that aggregates a node’s neighborhood (the set of nodes connected to the original node), using both graph structure and node features. These representations are a function of a node’s neighborhood and are not hardcoded per node (i.e. this is an inductive method), so changes in graph structure do not require re-training the model.

Conceptually, a single GCN layer can be thought of as taking a weighted average of the node features in the original node’s neighborhood, in which the weights are learned by training the network. We can then stack these GCN layers to produce aggregations that use more of the graph. Each GCN layer that we add expands the span of the subgraph used to produce a node’s embedding by 1 hop.
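The aggregation and the effect of stacking can be sketched in plain Python. This is a simplified mean-aggregation variant written for illustration, not the exact normalization used in the GCN paper: each node averages its own features with its neighbors', applies a weight matrix, and passes the result through a ReLU. The graph, features, and gcn_layer helper are hypothetical, and the weight matrix is fixed to the identity here rather than learned.

```python
# Hypothetical path graph: a - b - c.
graph = {
    "a": {"b"},
    "b": {"a", "c"},
    "c": {"b"},
}

# Hypothetical 2-dimensional node features; only "a" carries a signal.
features = {
    "a": [1.0, 0.0],
    "b": [0.0, 0.0],
    "c": [0.0, 0.0],
}


def mean_aggregate(graph, features, node):
    """Average the feature vectors of the node and its direct neighbors."""
    members = [node] + sorted(graph[node])
    dim = len(features[node])
    return [sum(features[m][d] for m in members) / len(members) for d in range(dim)]


def gcn_layer(graph, features, weight):
    """One simplified layer: aggregate, apply weight (rows = output dims), ReLU."""
    output = {}
    for node in graph:
        aggregated = mean_aggregate(graph, features, node)
        transformed = [sum(w * x for w, x in zip(row, aggregated)) for row in weight]
        output[node] = [max(0.0, value) for value in transformed]
    return output


# Assumption: identity weights stand in for parameters that training would learn.
identity = [[1.0, 0.0], [0.0, 1.0]]

hidden1 = gcn_layer(graph, features, identity)  # each node sees 1 hop
hidden2 = gcn_layer(graph, hidden1, identity)   # each node sees 2 hops

print(hidden1["c"])  # no signal from "a" yet, since "a" is 2 hops from "c"
print(hidden2["c"])  # the signal from "a" has now reached "c"
```

Because the layer is a function of a node's neighborhood rather than a per-node lookup table, the same gcn_layer call can be applied to a graph with newly added nodes without changing the weights.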

GCNs can be trained in both supervised and unsupervised ways. For supervised training, we simply train against the label that we have for our downstream task. For unsupervised training, there are several options, but the most common way is similar to the Node2vec/DeepWalk method, in which we enforce that close-by nodes have similar representations.

GCNs overcome many of the issues of the previous methods, as 1) they are trainable end-to-end, 2) they fully leverage node features, and 3) they do not rely on a fixed mapping from node to embedding.

Because of these strengths, this style of graph neural network has been used in many real-world industry systems such as recommender systems (e.g. PinSage) and fraud detection systems.


In this blog post, we covered three different ways to leverage graph information in machine learning models, ranging from simple graph statistics to end-to-end trainable graph neural networks.

These methods allow machine learning models to leverage the rich graph context that is present in many problems. After reading this post, hopefully it is clearer how you can leverage graph information in your models.

