import torch
def drop_row_and_col(A: torch.Tensor, i) -> torch.Tensor:
    """
    Remove the selected row(s) and the same column(s) from a 2D square tensor.

    Args:
        A (torch.Tensor): 2D square tensor of shape (n, n).
        i (int | Sequence[int] | torch.Tensor): 0-based index, or collection of
            indices, of the rows/columns to remove. Negative indices count from
            the end, as in ordinary tensor indexing.

    Returns:
        torch.Tensor: new tensor of shape (n - k, n - k), where k is the number
        of distinct positions selected by ``i`` (k = 1 for a single int).

    Raises:
        ValueError: if ``A`` is not a 2D square tensor.
        IndexError: if an index in ``i`` is out of range.
    """
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("Input must be a 2D square tensor.")
    keep = torch.ones(A.shape[0], dtype=torch.bool, device=A.device)
    keep[i] = False
    # Applying the same boolean mask to dim 0 and then dim 1 removes matching
    # rows and columns, so the result stays square.
    return A[keep][:, keep]
def drop_under_index(array: torch.Tensor, i):
    """
    Return a copy of ``array`` without the entries at position(s) ``i`` along dim 0.

    Args:
        array (torch.Tensor): tensor whose leading dimension is filtered.
        i: a single index, or a collection/tensor of indices, to discard.

    Returns:
        torch.Tensor: ``array`` restricted to the positions that were not dropped.
    """
    n_rows = array.shape[0]
    keep = torch.full((n_rows,), True, dtype=torch.bool, device=array.device)
    keep[i] = False  # works for one index or many at once
    kept = array[keep]
    return kept
def get_maximal_ind_set(adj_matrix, drop_at_once=1, ensure_maximal=True):
    """
    Greedy vertex-removal approximation of a maximal independent set.

    Repeatedly removes the ``drop_at_once`` nodes with the largest remaining
    degree until no edges are left among the surviving nodes. Removing several
    nodes per step (and arbitrary tie-breaking) can leave nodes out that could
    still be added. With ``ensure_maximal=True`` a final greedy pass re-adds
    such nodes, so the returned set is a maximal independent set.

    Args:
        adj_matrix (torch.Tensor): N x N adjacency matrix. Any non-zero entry is
            treated as an edge. An edge in either direction makes two nodes
            conflict, and a non-zero diagonal entry (self-loop) excludes a node.
        drop_at_once (int): how many nodes to remove per step (must be >= 1).
            Larger values speed up the computation a lot.
        ensure_maximal (bool): re-add removed nodes that have no neighbour in
            the selected set, making the result maximal.

    Returns:
        torch.Tensor: 1D int64 tensor with the indices of the selected nodes,
        in ascending order, on ``adj_matrix.device``.

    Raises:
        ValueError: if ``adj_matrix`` is not a 2D square tensor or
            ``drop_at_once < 1``.
    """
    if adj_matrix.ndim != 2 or adj_matrix.shape[0] != adj_matrix.shape[1]:
        raise ValueError("Input must be a 2D square tensor.")
    if drop_at_once < 1:
        # A non-positive value would select no node to drop and return every
        # node, which is not an independent set.
        raise ValueError("drop_at_once must be a positive integer.")

    edges = adj_matrix != 0
    n = edges.shape[0]
    alive = torch.ones(n, dtype=torch.bool, device=edges.device)
    # Row degree of every node counted over alive nodes only. It is updated
    # incrementally instead of re-summing a shrinking copy of the matrix each step.
    degree = edges.sum(-1)
    removed = []
    while True:
        score = degree.masked_fill(~alive, -1)  # dead nodes can never be picked
        ind = score.argsort(descending=True)[:drop_at_once]
        ind = ind[score[ind] > 0]
        if len(ind) == 0:
            # Every alive node has no edge to another alive node (nor a self-loop).
            break
        alive[ind] = False
        degree -= edges[:, ind].sum(-1)
        removed.append(ind)

    if ensure_maximal and removed:
        conflicts = edges | edges.T
        # Any visiting order yields a maximal set. Reverse removal order tends to
        # revisit nodes that had lower degree when they were removed.
        for node in torch.cat(removed).flip(0).tolist():
            if not (conflicts[node] & alive).any() and not conflicts[node, node]:
                alive[node] = True

    return torch.nonzero(alive).flatten()
The simplest usage looks like this:
# Random undirected graph with 500 nodes.
adj = torch.randn((500, 500)) > 0.5
adj = adj | adj.T  # undirected graph: keep the adjacency matrix symmetric
adj[torch.arange(500), torch.arange(500)] = False  # no self-loops
get_maximal_ind_set(adj, drop_at_once=3)
You can also run it on a GPU, which is much faster:
# Same random undirected graph, computed on the GPU when one is available.
adj = torch.randn((500, 500)) > 0.5
adj = adj | adj.T  # undirected graph: keep the adjacency matrix symmetric
adj[torch.arange(500), torch.arange(500)] = False  # no self-loops
# Fall back to the CPU so the example still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
get_maximal_ind_set(adj.to(device), drop_at_once=3).cpu()
Here is a simple visualization of how it works:
import networkx as nx
import matplotlib.pyplot as plt  # `plt` is used below and must be imported

# ==== Generate random adjacency matrix ====
torch.manual_seed(0)
n = 10
adj_matrix = (torch.rand((n, n)) > 0.5).int()
adj_matrix = torch.triu(adj_matrix, 1)  # upper triangle only
adj_matrix = adj_matrix + adj_matrix.T  # make symmetric (undirected graph)
adj_matrix.fill_diagonal_(0)  # no self-loops (already zero after triu(..., 1); kept as a safeguard)

# ==== Find maximal independent set ====
ind_set = get_maximal_ind_set(adj_matrix, drop_at_once=1)
print("Independent set indices:", ind_set.tolist())

# ==== Visualize ====
G = nx.Graph()
G.add_nodes_from(range(n))
# Each undirected edge appears exactly once in the strict upper triangle.
G.add_edges_from(tuple(e) for e in torch.nonzero(torch.triu(adj_matrix, 1)).tolist())

pos = nx.spring_layout(G, seed=42)  # layout for consistent visualization
# Node colors: blue if in independent set, red otherwise.
# A Python set gives O(1) membership instead of scanning the tensor for every node.
selected = set(ind_set.tolist())
node_colors = ['tab:blue' if i in selected else 'tab:red' for i in G.nodes()]
plt.figure(figsize=(6, 6))
nx.draw(
    G,
    pos,
    with_labels=True,
    node_color=node_colors,
    node_size=600,
    font_color='white',
    edge_color='gray',
)
plt.title("Graph with Maximal Independent Set (blue nodes)")
plt.show()