entity_embedding.py
Overview
The entity_embedding.py module generates node embeddings for graph data structures using the Node2Vec algorithm. Node embeddings map each node of a graph to a dense vector that captures structural and relational information. The module works primarily with NetworkX graphs (nx.Graph or nx.DiGraph). It relies on the graspologic library for the Node2Vec computation and on the graphrag library for largest-connected-component extraction.
Key features include:
- Optionally extracting the largest connected component (LCC) of a graph, so embeddings are computed on a single connected subgraph.
- Generating Node2Vec embeddings with configurable hyperparameters.
- Packaging the embeddings with node identifiers for easy downstream use.
Classes and Functions
NodeEmbeddings (dataclass)
A simple container class to hold node identifiers and their corresponding embeddings.
Attributes:

| Name | Type | Description |
|---|---|---|
| `nodes` | `list[str]` | List of node identifiers. |
| `embeddings` | `np.ndarray` | 2D array where each row corresponds to a node's embedding vector. |
Usage Example:
```python
import numpy as np

# NodeEmbeddings is the dataclass defined in entity_embedding.py
embeddings = NodeEmbeddings(
    nodes=["node1", "node2", "node3"],
    embeddings=np.array([[0.1, 0.2], [0.4, 0.5], [0.7, 0.8]]),
)
print(embeddings.nodes)             # ['node1', 'node2', 'node3']
print(embeddings.embeddings.shape)  # (3, 2)
```
embed_node2vec
```python
def embed_node2vec(
    graph: nx.Graph | nx.DiGraph,
    dimensions: int = 1536,
    num_walks: int = 10,
    walk_length: int = 40,
    window_size: int = 2,
    iterations: int = 3,
    random_seed: int = 86,
) -> NodeEmbeddings
```
Generate node embeddings for a graph using the Node2Vec algorithm.
Parameters:

| Name | Type | Default | Description |
|---|---|---|---|
| `graph` | `nx.Graph` or `nx.DiGraph` | required | The input graph to embed. Can be undirected or directed. |
| `dimensions` | `int` | 1536 | Dimensionality of the embedding vectors. |
| `num_walks` | `int` | 10 | Number of random walks to start at each node. |
| `walk_length` | `int` | 40 | Length of each random walk. |
| `window_size` | `int` | 2 | Window size for the skip-gram model in Node2Vec. |
| `iterations` | `int` | 3 | Number of iterations (epochs) for training the skip-gram model. |
| `random_seed` | `int` | 86 | Seed for reproducibility of random walks and training. |
Returns:
NodeEmbeddings: A dataclass instance containing the list of nodes and their corresponding embedding vectors.
Behavior and Implementation Details:
The function delegates embedding computation to `graspologic.embed.node2vec_embed`, which internally:
- generates random walks over the graph;
- trains a skip-gram model to learn node representations.

`node2vec_embed` returns the embeddings and the corresponding nodes as a tuple. That tuple is then packaged into a `NodeEmbeddings` instance. No largest-connected-component filtering happens in this function; that step lives in `run`.
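The following stdlib-only sketch makes the two internal steps above concrete. It is not graspologic's implementation. It uses uniform random walks, which is the special case of Node2Vec with p = q = 1, over a plain adjacency dictionary. It stops at producing the (center, context) pairs that a skip-gram model would be trained on. The helper names `random_walks` and `skipgram_pairs` are hypothetical.

```python
import random


def random_walks(
    adjacency: dict[str, list[str]],
    num_walks: int = 10,
    walk_length: int = 40,
    random_seed: int = 86,
) -> list[list[str]]:
    """Start num_walks uniform random walks at every node (hypothetical helper)."""
    rng = random.Random(random_seed)  # seeded so repeated calls give identical walks
    walks: list[list[str]] = []
    for _ in range(num_walks):
        starts = sorted(adjacency)  # fixed order first, then a seeded shuffle per round
        rng.shuffle(starts)
        for start in starts:
            walk = [start]
            while len(walk) < walk_length:
                neighbors = adjacency[walk[-1]]
                if not neighbors:  # dead end (e.g. an isolated node): stop this walk early
                    break
                walk.append(rng.choice(neighbors))  # uniform step, i.e. Node2Vec with p = q = 1
            walks.append(walk)
    return walks


def skipgram_pairs(walks: list[list[str]], window_size: int = 2) -> list[tuple[str, str]]:
    """Emit the (center, context) pairs a skip-gram model would train on."""
    pairs: list[tuple[str, str]] = []
    for walk in walks:
        for i, center in enumerate(walk):
            # context = up to window_size nodes on each side of the center position
            for j in range(max(0, i - window_size), min(len(walk), i + window_size + 1)):
                if j != i:
                    pairs.append((center, walk[j]))
    return pairs
```

Consider the documented defaults (num_walks=10, walk_length=40) on the 34-node karate club graph. Under those settings this sketch would produce 10 × 34 = 340 walks of at most 40 nodes each. A larger `window_size` yields more context pairs per walk position.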
Usage Example:
```python
import networkx as nx

# embed_node2vec is the function defined in entity_embedding.py
graph = nx.karate_club_graph()
embeddings = embed_node2vec(graph, dimensions=128)
print(embeddings.nodes)             # list of node IDs
print(embeddings.embeddings.shape)  # (number_of_nodes, 128)
```
run
```python
def run(graph: nx.Graph, args: dict[str, Any]) -> dict
```
Primary entry point function to generate embeddings from a graph with configurable parameters.
Parameters:

| Name | Type | Description |
|---|---|---|
| `graph` | `nx.Graph` | Input graph on which embeddings are computed. |
| `args` | `dict[str, Any]` | Dictionary of options controlling embedding generation. |
Expected keys in args and their defaults:

| Key | Default | Description |
|---|---|---|
| `use_lcc` | | Whether to extract the largest connected component before embedding. |
| `dimensions` | 1536 | Embedding vector size. |
| `num_walks` | 10 | Number of random walks per node. |
| `walk_length` | 40 | Length of each walk. |
| `window_size` | 2 | Context window size for skip-gram. |
| `iterations` | 3 | Number of training iterations. |
| `random_seed` | 86 | Seed for random operations. |
Returns:
dict: A dictionary mapping node identifiers (sorted lexically) to their embedding vectors represented as lists.
Implementation Details:
1. If `use_lcc` is enabled, the graph is reduced to its stable largest connected component using `stable_largest_connected_component` from the `graphrag` library. This is intended to improve embedding quality by excluding isolated nodes and small disconnected fragments. Nodes outside that component therefore do not appear in the returned dictionary.
2. `embed_node2vec` is called with the parameters taken from `args`.
3. The output nodes and embeddings are zipped, sorted by node identifier, and returned as a dictionary for easy lookup.
Usage Example:
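```python
import networkx as nx

# run is the function defined in entity_embedding.py.
# karate_club_graph() uses integer node IDs (0..33); relabel them to strings
# so the result can be indexed with "0", "1", ...
graph = nx.relabel_nodes(nx.karate_club_graph(), str)
args = {
    "use_lcc": True,
    "dimensions": 128,
    "num_walks": 20,
    "walk_length": 50,
    "window_size": 5,
    "iterations": 5,
    "random_seed": 42,
}
embedding_dict = run(graph, args)
print(embedding_dict["0"])  # embedding vector for node "0" as a list of floats
```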
Important Implementation Details and Algorithms
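Largest Connected Component (LCC):
The module optionally extracts the largest connected component of the graph before embedding. This removes isolated nodes and small disconnected subgraphs, whose random walks carry little structural signal and could degrade embedding quality. It uses the function `stable_largest_connected_component` from the `graphrag` library. That function is designed to select and order the component consistently, so the same input graph yields the same subgraph on every run.

Node2Vec Algorithm:
The core embedding approach is Node2Vec. It samples node neighborhoods with biased random walks, then trains a skip-gram model so that nodes appearing in similar walk contexts receive similar vectors. Depending on the walk bias, the embeddings emphasize community membership or structural equivalence. `embed_node2vec` does not expose Node2Vec's return (p) and in-out (q) parameters, so graspologic's defaults for these apply. Parameters such as `walk_length`, `num_walks`, and `window_size` determine how much neighborhood context each node sees. They therefore directly influence the resulting embeddings.

Integration with graspologic:
The actual embedding logic is delegated to `graspologic.embed.node2vec_embed`. This module forwards the hyperparameters and repackages the returned embeddings and node list.

Deterministic Output:
The `random_seed` parameter is forwarded to the walk generation and training steps so that embeddings are reproducible for the same graph and parameters. Bit-for-bit reproducibility can still depend on the underlying trainer, since multithreaded skip-gram training is not always deterministic.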
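The second-order walk bias mentioned under Node2Vec Algorithm can be sketched with the standard library. The sketch follows the transition rule from the original Node2Vec paper for an unweighted graph. A candidate next node is weighted 1/p if it returns to the previous node, 1 if it is also a neighbor of the previous node, and 1/q otherwise. This is an illustration, not graspologic's code. `node2vec_step` and `node2vec_walk` are hypothetical names, and the seeded `random.Random` stands in for the role `random_seed` plays in the module.

```python
import random


def node2vec_step(
    adjacency: dict[str, set[str]],
    prev: str,
    current: str,
    p: float,
    q: float,
    rng: random.Random,
) -> str:
    """Choose the next node with Node2Vec's second-order bias on an unweighted graph."""
    candidates = sorted(adjacency[current])  # sorted so a fixed seed yields a fixed choice
    weights = []
    for nxt in candidates:
        if nxt == prev:  # returning to the previous node
            weights.append(1.0 / p)
        elif nxt in adjacency[prev]:  # staying one hop from the previous node
            weights.append(1.0)
        else:  # moving two hops away from the previous node
            weights.append(1.0 / q)
    return rng.choices(candidates, weights=weights, k=1)[0]


def node2vec_walk(
    adjacency: dict[str, set[str]],
    start: str,
    walk_length: int,
    p: float = 1.0,
    q: float = 1.0,
    random_seed: int = 86,
) -> list[str]:
    """Run one biased walk; with p = q = 1 every step is uniform."""
    rng = random.Random(random_seed)
    walk = [start]
    while len(walk) < walk_length:
        neighbors = adjacency[walk[-1]]
        if not neighbors:  # dead end: stop early
            break
        if len(walk) == 1:  # no previous node yet, so the first step is uniform
            walk.append(rng.choice(sorted(neighbors)))
        else:
            walk.append(node2vec_step(adjacency, walk[-2], walk[-1], p, q, rng))
    return walk
```

Setting `p = q = 1` makes every weight 1 and recovers a uniform random walk. A large `p` discourages immediate backtracking. A small `q` pushes walks outward (more DFS-like), while a large `q` keeps them local (more BFS-like).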
Interaction with Other System Components
`graphrag.general.leiden`:
The function `stable_largest_connected_component` is imported from this module to reduce the graph to its stable largest connected component before embedding.

`graspologic`:
Provides the underlying embedding function `node2vec_embed`. This is a critical dependency for embedding computations.

Graph Input/Output:
The module expects NetworkX graph objects as input. It returns embeddings either as `NodeEmbeddings` dataclasses (from `embed_node2vec`) or as dictionaries keyed by node identifier (from `run`). Both forms integrate easily with downstream tasks such as node classification, clustering, or visualization.
Visual Diagram
```mermaid
classDiagram
    class NodeEmbeddings {
        +nodes: list[str]
        +embeddings: np.ndarray
    }
    class embed_node2vec {
        +graph: nx.Graph | nx.DiGraph
        +dimensions: int
        +num_walks: int
        +walk_length: int
        +window_size: int
        +iterations: int
        +random_seed: int
        +return: NodeEmbeddings
    }
    class run {
        +graph: nx.Graph
        +args: dict[str, Any]
        +return: dict[str, list[float]]
    }
    run --> embed_node2vec : calls
    embed_node2vec --> NodeEmbeddings : returns
```
Summary
The entity_embedding.py file is a focused utility for generating Node2Vec embeddings from graph data, with optional preprocessing to extract the largest connected component. It exposes a small API (`embed_node2vec` and `run`). Results come back either as a `NodeEmbeddings` dataclass or as a dictionary keyed by node identifier. The module is designed to slot into graph analytics pipelines, with configurable hyperparameters and a fixed random seed for reproducible runs.