Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update node_representation_learning.md #59

Merged
merged 1 commit into from
Jan 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions docs/use_cases/node_representation_learning.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

Of the various types of information - words, pictures, and connections between things - relationships are especially interesting; they show how things interact and create networks. But not all ways of representing these relationships are the same. In machine learning, how we do vector represention of relationships is consequential for performance on a wide range of tasks.

We go through how to set up a Bag-of-Words (BoW) approach to representing relationship data, and then two other approaches - Node2Vec and GraphSAGE. We compare and evaluate how well each approach represents academic articles in the Cora citation network, by measuring their performance on real-life classification and similarity tasks.
We evaluate several approaches to vector representation on their ability using a real-life use case: how well they classify academic articles and replicate a citation graph using Cora citation network. We look first at Bag-of-Words. Because BoW doesn't represent the network structure, we examine Node2Vec...(improvement, but static) and GraphSAGE (for dynamic networks).
Finally, BoW has another weakness. It also fails to capture semantic meaning. LLM embeddings, on the other hand, are designed to represent semantic meaning. We look at how LLM-only, and then Node2Vec + LLM, and GraphSAGE trained on LLM compared to our first set of approaches: BoW alone, Node2Vec + BoW, and GraphSAGE trained on BoW.


**Our dataset: Cora**
Expand Down Expand Up @@ -57,16 +58,15 @@ In this plot, we define groups (shown on the y-axis) so that each group has abou

The plot demonstrates how connected nodes usually have higher cosine similarities. Papers that cite each other often use similar words. But if we ignore paper pairs with zero similarities (the 0.00-0.00 group), papers that have _not_ cited each other also seem to have a wide range of common words.

Though BoW representations embody _some_ information about article connectivity, BoW features don't contain enough citation pair information to accurately reconstruct the actual citation graph. More specifically, because BoW looks exclusively at word co-occurrence between article pairs, it misses additional information contained in the network structure - namely the semantic relationships and context of words - information that can be used to more accurately represent citation data, and classify articles better. (Because articles that cite each other tend to belong to the same topic, we could achieve improvements in both citation graph reproduction and article classification by representing both citation information and textual information contained in our network.)
Though BoW representations embody _some_ information about article connectivity, BoW features don't contain enough citation pair information to accurately reconstruct the actual citation graph. More specifically, BoW represents documents as unordered sets of words; it ignores word order, and instead treats each word independently. BoW vectors basically represent the frequency of words in a document. Because BoW looks exclusively at word co-occurrence between article pairs, it misses word context data contained in the network structure - data that can be used to more accurately represent citation data, and classify articles better. (Because articles that cite each other tend to belong to the same topic, we can achieve improvements in both citation graph reproduction and article classification by representing both citation information and textual information contained in our network.)

If BoW is not sufficient, what methods might be better at extracting the data inherent but still latent in our dataset?
Let's look at two methods for learning node representations that capture nodes and node connectivity more accurately.
Can we make up for BoW's inability to represent the citation network's structure? What methods might be better at capturing node and node connectivity data better?

## Learning node embeddings with Node2Vec
Node2Vec is built to do precisely this. So is GraphSAGE. First, let's look at Node2Vec.

Node embeddings are vector representations that capture the structural role and properties of nodes in a network.
## Learning node embeddings with Node2Vec

Node2Vec is an algorithm that learns node representations using the Skip-Gram method. The Node2Vec algorithm models the conditional probability of encountering a context node given a source node in node sequences (random walks):
As opposed to BoW vectors, node embeddings are vector representations that capture the structural role and properties of nodes in a network. Node2Vec is an algorithm that learns node representations using the Skip-Gram method; it models the conditional probability of encountering a context node given a source node in node sequences (random walks):

$P(\text{context}|\text{source}) = \frac{1}{Z}\exp(w_{c}^Tw_s) $

Expand All @@ -79,7 +79,7 @@ The random walks are sampled according to a policy, which is guided by 2 paramet
- The return parameter $p$ affects the likelihood of immediately returning to the previous node. A higher $p$ leads to more locally focused walks.
- The in-out parameter $q$ affects the likelihood of visiting nodes in the same or a different neighborhood. A higher $q$ encourages Depth First Search, while a lower $q$ promotes Breadth-First-Search-like exploration.

These parameters are particularly useful for accomodating different networks and tasks. Adjusting the values of $p$ and $q$ captures different characteristics of the graph in the sampled walks. BFS-like exploration is useful for learning local patterns. On the other hand, using DFS-like sampling is useful for capturing patterns from a bigger scale, like structural roles.
These parameters are particularly useful for accommodating different networks and tasks. Adjusting the values of $p$ and $q$ captures different characteristics of the graph in the sampled walks. BFS-like exploration is useful for learning local patterns. On the other hand, using DFS-like sampling is useful for capturing patterns on a bigger scale, like structural roles.

### Node2Vec embeddings

Expand Down Expand Up @@ -145,16 +145,15 @@ Let's also see if Node2Vec also does a better job of **representing citation dat

This time, using Node2Vec we can see a well defined separation; these embeddings capture the connectivity of the graph much better than BoW.


Let’s see if we can further improve classification performance by combining the two information sources, relations and textual features.
But can we further improve classification performance by _combining_ the two information sources, relations (Node2Vec) embeddings and textual (BoW) features?

### Node2Vec + Text-based (BoW) embeddings

A straightforward approach for combining embeddings from different sources is by concatenating them dimension-wise. We have BoW features `v_bow` and Node2Vec embeddings `v_n2v`. The fused representation would then be `v_fused = torch.cat((v_n2v, v_bow), dim=1)`. However, before combining the two representations, let’s look at the L2 norm distribution of both embeddings:
A straightforward approach for combining embeddings from different sources is to concatenate them dimension-wise. We have BoW features `v_bow` and Node2Vec embeddings `v_n2v`. The fused representation would then be `v_fused = torch.cat((v_n2v, v_bow), dim=1)`. Before combining them, we should examine the L2 norm distribution of both embeddings, to ensure that one kind of representations will not dominate the other:

![L2 norm distribution of text based and Node2Vec embeddings](../assets/use_cases/node_representation_learning/l2_norm.png)

From the plot, it's clear that the scales of the embedding vector lengths differ. When we want to use them together, the one with the larger magnitude will overshadow the smaller one. To mitigate this, we can divide each embedding vector by their average length. But we can _further_ optimize performance by introducing a weighting factor ($\alpha$). The combined representations are constructed as `x = torch.cat((alpha * v_n2v, v_bow), dim=1)`. To determine the appropriate value for $\alpha$, we employ a 1D grid search approach. The results are displayed in the following plot.
From the plot above, it's clear that the scales of the embedding vector lengths differ. To avoid the larger magnitude Node2Vec vector overshadowing the BoW vector, we can divide each embedding vector by their average length. But we can _further_ optimize performance by introducing a **weighting factor** ($\alpha$). The combined representations are constructed as `x = torch.cat((alpha * v_n2v, v_bow), dim=1)`. To determine the appropriate value for $\alpha$, we employ a 1D grid search approach. The results are displayed in the following plot.

![Grid search for alpha](../assets/use_cases/node_representation_learning/grid_search_alpha_bow.png)

Expand All @@ -170,9 +169,9 @@ evaluate(x, ds.y)
>>> F1 macro 0.831
```

The results show that by combining the representations obtained from solely the network structure and text of the paper can improve performance. Specifically, in our case, this fusion contributed to a 3.6% improvement from the Node2Vec-only and 15.4% from the BoW-only classifiers.
By combining the representations of the network structure (Node2Vec) and text (BoW) of the paper, we were able to improve performance on article classification. Specifically, the Node2Vec + BoW fusion resulted in a 3.6% improvement from the Node2Vec-only and 15.4% from the BoW-only classifiers.

These are impressive results. **But what if we are given new papers to classify?**
These are impressive results. **But what if our citation network grows? What happens when new papers need to be classified?**

### Node2Vec limitations: dynamic networks

Expand Down Expand Up @@ -282,9 +281,9 @@ evaluate(embeddings, ds.y)

The results are slightly worse than the results we got by combining Node2Vec with BoW features. But the reason we are evaluating GraphSAGE is that Node2Vec's inability to easily accommodate to dynamic networks. GraphSAGE embeddings perform well on our classification task _and_ is able to embed completely new nodes as well. When your use case involves new nodes or nodes that evolve, an induction model like GraphSAGE may be a better choice.

## Using better node representations: LLM
## Using better node representations than BoW: LLM

Bag-of-Words representation is a simple and easy way of embedding text documents, but it comes with limitations: because it treats words as contextless, it doesn't capture semantic meaning, and therefore performs less well (on classification and ) article relatedness...
In addition to not being able to represent network structure, BoW vectors - because they treat words as contextless occurrences, merely in terms of their frequency - can't capture semantic meaning, and therefore performs less well (on classification and ) article relatedness... tasks than approaches that can do semantic embedding.
Here is a summary of what we've found so far using BoW representations of our citation network.

| Metric | BoW | Node2Vec | Node2Vec+BoW | GraphSAGE+BoW(?) |
Expand All @@ -295,6 +294,8 @@ Here is a summary of what we've found so far using BoW representations of our ci
(...All of the improvements we experienced above, BoW alone, Node2Vec + BoW, and GraphSAGE (+ BoW) can be improved further using LLM embeddings, because they excel in capturing semantic meaning...

We used the `all-mpnet-base-v2` model available on [Hugging Face](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) for embedding the title and abstract of each paper.
...everything is done in exactly the same way as with the BoW features, we just simply replace them with the LLM features..

The results obtained with LLM only, Node2Vec combined with LLM and GraphSAGE trained on LLM features can be found in the following table along with the relative improvement compared to using the BoW features:

| Metric | LLM | Node2Vec | GraphSAGE |
Expand Down
Loading