// leetcode / algorithms-templates / 6-graphs
 
 
 
import java.util.Map;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ArrayList;
import java.util.Queue;
import java.util.PriorityQueue;
 
public class Prim {
    /**
     * Computes a minimum spanning tree of a connected, undirected, weighted graph
     * using the lazy variant of Prim's algorithm.
     *
     * <p>Nodes are numbered {@code 1..n}. The tree is grown starting from node 1. Each
     * returned edge is {@code {from, to}}, where {@code from} was already in the tree
     * and {@code to} is the node that edge attached. Edges are listed in the order
     * they were added to the tree.
     *
     * @param edges undirected weighted edges, each given as {@code {node1, node2, weight}}
     * @param n number of nodes in the graph
     * @return the {@code n - 1} tree edges (empty when {@code n <= 1})
     * @throws IllegalArgumentException if an edge endpoint lies outside {@code 1..n},
     *     or if the graph is not connected
     */
    public static List<Integer[]> mst(int[][] edges, int n) {
        List<List<int[]>> adj = buildAdjacency(edges, n);
        List<Integer[]> mst = new ArrayList<>();
        if (n <= 1) {
            // Zero or one node: the spanning tree has no edges.
            return mst;
        }

        // Min-heap of candidate edges {weight, from, to}, ordered by weight.
        // Integer.compare is used instead of subtraction, which overflows for
        // extreme weights (e.g. Integer.MAX_VALUE vs. a negative weight).
        Queue<int[]> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        boolean[] visited = new boolean[n + 1];

        // Seed the tree with node 1 and offer all of its edges.
        visited[1] = true;
        int visitedCount = 1;
        pushFrontier(minHeap, adj.get(1), visited, 1);

        while (visitedCount < n) {
            int[] cur = minHeap.poll();
            if (cur == null) {
                // No edge leads to the remaining nodes.
                throw new IllegalArgumentException(
                        "graph is not connected: reached " + visitedCount + " of " + n + " nodes");
            }
            int from = cur[1], to = cur[2];
            if (visited[to]) {
                // Stale entry: 'to' was attached via a cheaper edge after this one was queued.
                continue;
            }
            visited[to] = true;
            visitedCount++;
            mst.add(new Integer[] {from, to});
            pushFrontier(minHeap, adj.get(to), visited, to);
        }
        return mst;
    }

    /** Builds a 1-indexed adjacency list where each entry is {@code {neighbor, weight}}. */
    private static List<List<int[]>> buildAdjacency(int[][] edges, int n) {
        List<List<int[]>> adj = new ArrayList<>(Math.max(n, 0) + 1);
        for (int i = 0; i <= n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            int n1 = edge[0], n2 = edge[1], weight = edge[2];
            if (n1 < 1 || n1 > n || n2 < 1 || n2 > n) {
                throw new IllegalArgumentException(
                        "edge (" + n1 + ", " + n2 + ") has an endpoint outside 1.." + n);
            }
            // Undirected graph: record the edge in both directions.
            adj.get(n1).add(new int[] {n2, weight});
            adj.get(n2).add(new int[] {n1, weight});
        }
        return adj;
    }

    /** Queues every edge from {@code node} to a not-yet-visited neighbor as {@code {weight, node, neighbor}}. */
    private static void pushFrontier(
            Queue<int[]> minHeap, List<int[]> neighbors, boolean[] visited, int node) {
        for (int[] pair : neighbors) {
            int neighbor = pair[0], weight = pair[1];
            if (!visited[neighbor]) {
                minHeap.add(new int[] {weight, node, neighbor});
            }
        }
    }
}