entity_embedding.py


Overview

The entity_embedding.py module provides functionality to generate node embeddings for graph data structures using the Node2Vec algorithm. Node embeddings transform nodes in a graph into dense vector representations, capturing structural and relational information. This file is designed to work primarily with NetworkX graphs (nx.Graph or nx.DiGraph) and leverages the graphrag and graspologic libraries for embedding computations.

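Key features include Node2Vec embedding of undirected and directed NetworkX graphs, optional extraction of the largest connected component before embedding, configurable walk and training parameters, and a seed parameter for reproducible runs.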


Classes and Functions

NodeEmbeddings (dataclass)

A simple container class to hold node identifiers and their corresponding embeddings.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| nodes | list[str] | List of node identifiers. |
| embeddings | np.ndarray | 2D array where each row corresponds to a node's embedding vector. |

Usage Example:

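import numpy as np
from entity_embedding import NodeEmbeddings  # import path assumed; adjust to your package layout

embeddings = NodeEmbeddings(
    nodes=["node1", "node2", "node3"],
    embeddings=np.array([[0.1, 0.2], [0.4, 0.5], [0.7, 0.8]])
)
print(embeddings.nodes)       # ['node1', 'node2', 'node3']
print(embeddings.embeddings)  # array with shape (3, 2)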
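
Assuming row i of embeddings belongs to nodes[i] (the natural reading of the attribute table), a single node's vector can be retrieved through a position index. The sketch below is illustrative only: build_row_index and vector_for are hypothetical helpers, not part of the module, and nested lists stand in for the np.ndarray (row indexing works the same way on both).

```python
def build_row_index(nodes):
    """Map each node identifier to its row position (hypothetical helper)."""
    return {node: row for row, node in enumerate(nodes)}


def vector_for(nodes, embeddings, node_id):
    """Return the embedding row aligned with node_id (hypothetical helper)."""
    index = build_row_index(nodes)
    if node_id not in index:
        raise KeyError(f"unknown node: {node_id!r}")
    return embeddings[index[node_id]]


# Same data as the usage example, with nested lists standing in
# for the 2D NumPy array.
nodes = ["node1", "node2", "node3"]
rows = [[0.1, 0.2], [0.4, 0.5], [0.7, 0.8]]
node2_vector = vector_for(nodes, rows, "node2")
print(node2_vector)  # [0.4, 0.5]
```

For repeated lookups, building the index once and reusing it avoids rescanning the node list on every call.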

embed_node2vec

def embed_node2vec(
    graph: nx.Graph | nx.DiGraph,
    dimensions: int = 1536,
    num_walks: int = 10,
    walk_length: int = 40,
    window_size: int = 2,
    iterations: int = 3,
    random_seed: int = 86,
) -> NodeEmbeddings

Generate node embeddings for a graph using the Node2Vec algorithm.

Parameters:

| Name | Type | Default | Description |
| --- | --- | --- | --- |
| graph | nx.Graph or nx.DiGraph | (required) | The input graph to embed. Can be undirected or directed. |
| dimensions | int | 1536 | Dimensionality of the embedding vectors. |
| num_walks | int | 10 | Number of random walks to start at each node. |
| walk_length | int | 40 | Length of each random walk. |
| window_size | int | 2 | Window size for the skip-gram model in Node2Vec. |
| iterations | int | 3 | Number of iterations (epochs) for training the skip-gram model. |
| random_seed | int | 86 | Seed for reproducibility of random walks and training. |

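Returns:

A NodeEmbeddings instance holding the node identifiers and a 2D embedding array with one row per node and dimensions columns.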

Behavior and Implementation Details:
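
As a rough illustration of how num_walks, walk_length, and random_seed interact, the sketch below generates uniform random walks over a small adjacency dictionary. It is a simplification under stated assumptions, not the module's implementation: the module relies on graphrag and graspologic for the actual computation, Node2Vec proper biases each step with return and in-out parameters that are omitted here, walk_length is taken to mean the number of nodes in a walk, and generate_walks and the toy graph are hypothetical.

```python
import random


def generate_walks(adjacency, num_walks=10, walk_length=40, random_seed=86):
    """Start num_walks uniform random walks at every node (hypothetical helper)."""
    rng = random.Random(random_seed)  # a fixed seed makes the walks repeatable
    walks = []
    for _ in range(num_walks):
        for start in adjacency:
            walk = [start]
            while len(walk) < walk_length:
                neighbors = adjacency[walk[-1]]
                if not neighbors:  # dead end, e.g. a sink in a directed graph
                    break
                walk.append(rng.choice(neighbors))
            walks.append(walk)
    return walks


# Toy undirected graph: a triangle a-b-c with a pendant node d attached to c.
toy_graph = {
    "a": ["b", "c"],
    "b": ["a", "c"],
    "c": ["a", "b", "d"],
    "d": ["c"],
}
walks = generate_walks(toy_graph, num_walks=2, walk_length=5, random_seed=86)
print(len(walks))  # 8: num_walks * number of nodes
```

The resulting walks play the role of sentences for the skip-gram model, so raising num_walks or walk_length enlarges the training corpus at the cost of run time.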

Usage Example:

import networkx as nx
from entity_embedding import embed_node2vec  # import path assumed; adjust to your package layout

graph = nx.karate_club_graph()
embeddings = embed_node2vec(graph, dimensions=128)
print(embeddings.nodes)             # List of node IDs
print(embeddings.embeddings.shape)  # (number_of_nodes, 128)
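
The window_size parameter determines which nodes in a walk are treated as context for one another when the skip-gram model is trained. The sketch below, assuming a symmetric window, lists the (center, context) pairs a single walk would yield. context_pairs is a hypothetical helper written for illustration and does not reproduce the library's training code.

```python
def context_pairs(walk, window_size=2):
    """List (center, context) pairs within window_size positions of each other."""
    pairs = []
    for i, center in enumerate(walk):
        lo = max(0, i - window_size)
        hi = min(len(walk), i + window_size + 1)
        for j in range(lo, hi):
            if j != i:  # a node is not its own context
                pairs.append((center, walk[j]))
    return pairs


walk = ["a", "b", "c", "d"]
pairs = context_pairs(walk, window_size=1)
print(pairs)
# [('a', 'b'), ('b', 'a'), ('b', 'c'), ('c', 'b'), ('c', 'd'), ('d', 'c')]
```

A larger window pairs nodes that sit further apart in a walk, which tends to emphasize broader neighborhood structure over immediate adjacency.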

run

def run(graph: nx.Graph, args: dict[str, Any]) -> dict

Primary entry point function to generate embeddings from a graph with configurable parameters.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| graph | nx.Graph | Input graph on which embeddings are computed. |
| args | dict[str, Any] | Dictionary of options controlling embedding generation. |

Expected keys in args and their defaults:

| Key | Default | Description |
| --- | --- | --- |
| use_lcc | True | Whether to extract the largest connected component before embedding. |
| dimensions | 1536 | Embedding vector size. |
| num_walks | 10 | Number of random walks per node. |
| walk_length | 40 | Length of each walk. |
| window_size | 2 | Context window size for skip-gram. |
| iterations | 3 | Number of training iterations. |
| random_seed | 86 | Seed for random operations. |

Returns:

A dictionary mapping each node identifier to its embedding vector, given as a list of floats.

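Implementation Details:

When use_lcc is true, the largest connected component of the graph is extracted before embedding. run then calls embed_node2vec with the remaining options and returns the resulting embeddings as a dictionary.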
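
One way to read the table of keys: any key missing from args falls back to its documented default. The sketch below shows that resolution with a plain dictionary merge. DEFAULT_ARGS and resolve_args are hypothetical names, rejecting unknown keys is a choice made for this sketch, and how run actually reads its options is not shown in this document.

```python
DEFAULT_ARGS = {
    "use_lcc": True,
    "dimensions": 1536,
    "num_walks": 10,
    "walk_length": 40,
    "window_size": 2,
    "iterations": 3,
    "random_seed": 86,
}


def resolve_args(args):
    """Overlay caller-supplied options on the documented defaults."""
    unknown = set(args) - set(DEFAULT_ARGS)
    if unknown:
        raise KeyError(f"unexpected option(s): {sorted(unknown)}")
    return {**DEFAULT_ARGS, **args}


resolved = resolve_args({"dimensions": 128, "random_seed": 42})
print(resolved["dimensions"], resolved["num_walks"])  # 128 10
```

Passing an empty dictionary therefore yields exactly the defaults listed in the table.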

Usage Example:

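import networkx as nx
from entity_embedding import run  # import path assumed; adjust to your package layout

# karate_club_graph() labels its nodes with integers; relabel them as
# strings so they match the string node identifiers used in this document.
graph = nx.relabel_nodes(nx.karate_club_graph(), str)
args = {
    "use_lcc": True,
    "dimensions": 128,
    "num_walks": 20,
    "walk_length": 50,
    "window_size": 5,
    "iterations": 5,
    "random_seed": 42,
}

embedding_dict = run(graph, args)
print(embedding_dict["0"])  # Embedding vector for node '0' as a list of floats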
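
Because run is described as returning lists of floats keyed by node identifier, the vectors can be compared without further conversion. As an illustrative sketch, cosine similarity between two nodes could be computed as follows. cosine_similarity is a hypothetical helper, and the small embedding_dict holds made-up vectors standing in for real output.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0  # convention chosen here for zero vectors
    return dot / (norm_u * norm_v)


# Made-up vectors in the shape run() is described as returning.
embedding_dict = {
    "0": [1.0, 0.0],
    "1": [0.0, 2.0],
    "2": [3.0, 0.0],
}
print(cosine_similarity(embedding_dict["0"], embedding_dict["2"]))  # 1.0
print(cosine_similarity(embedding_dict["0"], embedding_dict["1"]))  # 0.0
```

Cosine similarity ignores vector length, so it compares direction only; other measures such as Euclidean distance may suit some downstream tasks better.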

Important Implementation Details and Algorithms
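
The use_lcc option restricts embedding to the largest connected component. As a sketch of what that preprocessing step means, the code below finds the largest component of an undirected graph with a breadth-first search over an adjacency dictionary. It is an illustration under stated assumptions: largest_connected_component is a hypothetical helper, the module's own extraction routine is not shown in this document, and ties between equally large components are broken here by keeping the first one found.

```python
from collections import deque


def largest_connected_component(adjacency):
    """Return the node set of the largest component of an undirected graph."""
    seen = set()
    best = set()
    for start in adjacency:
        if start in seen:
            continue
        # Breadth-first search collects every node reachable from start.
        component = {start}
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for neighbor in adjacency[node]:
                if neighbor not in component:
                    component.add(neighbor)
                    queue.append(neighbor)
        seen |= component
        if len(component) > len(best):
            best = component
    return best


# Two components: {a, b, c} and {x, y}.
toy_graph = {
    "a": ["b"],
    "b": ["a", "c"],
    "c": ["b"],
    "x": ["y"],
    "y": ["x"],
}
lcc = largest_connected_component(toy_graph)
print(sorted(lcc))  # ['a', 'b', 'c']
```

Nodes outside the largest component receive no embedding when this step is enabled, so callers that need a vector for every node would set use_lcc to False.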


Interaction with Other System Components


Visual Diagram

classDiagram
    class NodeEmbeddings {
        +nodes: list[str]
        +embeddings: np.ndarray
    }

    class embed_node2vec {
        +graph: nx.Graph | nx.DiGraph
        +dimensions: int
        +num_walks: int
        +walk_length: int
        +window_size: int
        +iterations: int
        +random_seed: int
        +return: NodeEmbeddings
    }

    class run {
        +graph: nx.Graph
        +args: dict[str, Any]
        +return: dict[str, list[float]]
    }

    run --> embed_node2vec : calls
    embed_node2vec --> NodeEmbeddings : returns

Summary

The entity_embedding.py module generates Node2Vec embeddings from NetworkX graphs, optionally restricting the graph to its largest connected component first. embed_node2vec returns a NodeEmbeddings container, while run accepts a dictionary of options and returns a plain dictionary that maps node identifiers to embedding vectors. The walk and training parameters are all configurable, and random_seed supports reproducible runs, which makes the module straightforward to integrate into graph analytics pipelines.