求 LCA（最近公共祖先）的类：倍增法预处理，查询带权树上任意两点间的距离

python写法:

import queue


class Graph(object):
    """Weighted tree answering LCA and path-distance queries by binary lifting.

    Nodes are numbered 1..n and the tree is rooted at node 1.  Adjacency is a
    "forward star": ``head[x]`` is the id of the last edge added out of ``x``,
    ``nex[i]`` links to the previously added edge, and ``ver[i]`` / ``edge[i]``
    are the target node and weight of edge ``i``.  Edge ids start at 1 so that
    0 means "end of list"; node 0 acts as a sentinel above the root with depth
    0 and all-zero ancestors.
    """

    def add_edge(self, x, y, z):
        """Append the directed edge ``x -> y`` with weight ``z``."""
        self.tot += 1
        self.ver[self.tot] = y
        self.edge[self.tot] = z
        self.nex[self.tot] = self.head[x]
        self.head[x] = self.tot

    def __init__(self, n=None, m=None, edges=None):
        """Build the tree.

        Called with no arguments (the original behaviour), reads ``n m`` and
        then ``m`` lines of ``x y z`` from standard input.  Otherwise ``n`` is
        the node count and ``edges`` an iterable of ``(x, y, z)`` triples;
        ``m`` defaults to the number of triples.  Every triple is stored as an
        undirected edge (two directed edges).
        """
        if n is None:
            n, m = map(int, input().split())
            edges = [tuple(map(int, input().split())) for _ in range(m)]
        else:
            edges = list(edges or ())
            if m is None:
                m = len(edges)
        self.n, self.m = n, m
        # Size arrays from the actual input instead of a fixed 200007: the
        # edge arrays need 2 slots per undirected edge plus the unused slot 0,
        # which a fixed 200007 cannot hold once m exceeds ~100000.
        self.N = max(n, 1) + 1
        cap = 2 * len(edges) + 1
        # t = floor(log2(n)) + 1, computed exactly (no float rounding); the
        # ancestor table needs levels 0..t.
        self.t = max(n, 1).bit_length()
        self.f = [[0] * (self.t + 1) for _ in range(self.N)]
        self.d = [0] * self.N      # depth, root = 1, 0 = not yet visited
        self.dist = [0] * self.N   # weighted distance from node 1
        self.head = [0] * self.N
        self.nex = [0] * cap
        self.edge = [0] * cap
        self.ver = [0] * cap
        self.tot = 0
        for x, y, z in edges:
            self.add_edge(x, y, z)
            self.add_edge(y, x, z)

    def bfs(self):
        """Root the tree at node 1 and fill depth, distance and ancestor tables.

        Sets ``d[v]`` (depth, root = 1), ``dist[v]`` (weighted distance from
        node 1) and ``f[v][j]`` (the 2**j-th ancestor, 0 above the root) for
        every node reachable from node 1.
        """
        from collections import deque

        d, dist, f = self.d, self.dist, self.f
        d[1] = 1
        dist[1] = 0
        q = deque([1])
        while q:
            x = q.popleft()
            i = self.head[x]
            while i:
                y = self.ver[i]
                if not d[y]:
                    d[y] = d[x] + 1
                    dist[y] = dist[x] + self.edge[i]
                    f[y][0] = x
                    # BFS order guarantees every ancestor's row is complete.
                    for j in range(1, self.t + 1):
                        f[y][j] = f[f[y][j - 1]][j - 1]
                    q.append(y)
                i = self.nex[i]

    def lca(self, x, y):
        """Return the lowest common ancestor of nodes ``x`` and ``y``.

        Requires :meth:`bfs` to have been run first.
        """
        d, f = self.d, self.f
        if d[x] > d[y]:
            x, y = y, x
        # Lift the deeper node y up to x's depth, largest jumps first; a jump
        # is taken only if it does not land above x's level.
        for i in range(self.t, -1, -1):
            if d[f[y][i]] >= d[x]:
                y = f[y][i]
        if x == y:
            return x
        # Lift both while their 2**i-th ancestors differ; they stop just below
        # the LCA, so its parent is the answer.
        for i in range(self.t, -1, -1):
            if f[x][i] != f[y][i]:
                x = f[x][i]
                y = f[y][i]
        return f[x][0]

    def get_distance(self, x, y):
        """Return the weighted path length between nodes ``x`` and ``y``."""
        return self.dist[x] + self.dist[y] - 2 * self.dist[self.lca(x, y)]


# Read the tree from stdin, preprocess it, then answer the distance queries.
graph = Graph()
graph.bfs()
# int() rather than eval(): never evaluate raw input as Python code.
m = int(input())
for _ in range(m):
    x, y = map(int, input().split())
    print(graph.get_distance(x, y))
View Code

java写法:

package com.company;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Scanner;
/**
 * Weighted tree answering lowest-common-ancestor (LCA) and path-distance
 * queries by binary lifting.
 *
 * Nodes are numbered 1..n and the tree is rooted at node 1. Adjacency is a
 * "forward star": head[x] is the id of the last edge added out of x, nex[i]
 * links to the previously added edge, and ver[i] / edge[i] are the target and
 * weight of edge i. Edge ids start at 1 so 0 means "end of list"; node 0 is a
 * sentinel above the root with depth 0.
 *
 * NOTE: every field is static, so all Graph instances share the same arrays.
 */
class Graph {
    private final static int N = 200007;
    private final static int LOG = 20;
    private static int[] ver = new int[N];
    private static int[] nex = new int[N];
    private static int[] head = new int[N];
    private static int[] edge = new int[N];
    private static int[] d = new int[N];        // depth; root = 1, 0 = unvisited
    private static int[] dist = new int[N];     // weighted distance from node 1
    private static int[][] f = new int[N][LOG]; // f[v][j] = 2^j-th ancestor of v
    private static int n, m, tot, t;

    /** Appends the directed edge x -> y with weight z. */
    void add_edge(int x, int y, int z) {
        ver[++tot] = y;
        nex[tot] = head[x];
        edge[tot] = z;
        head[x] = tot;
    }

    /**
     * Reads the tree from a new Scanner on System.in.
     * Callers that read more input afterwards should call
     * init_graph(Scanner) with their own Scanner instead: a second Scanner on
     * System.in can miss bytes the first one has already buffered.
     */
    void init_graph() {
        init_graph(new Scanner(System.in));
    }

    /**
     * Reads "n m" followed by m lines "x y z" from cin and stores each line as
     * an undirected edge. Clears all state left behind by a previous call.
     *
     * @throws IllegalArgumentException if n or m exceed the fixed capacity N
     */
    void init_graph(Scanner cin) {
        n = cin.nextInt();
        m = cin.nextInt();
        // Each undirected edge takes two slots, and slot 0 is reserved.
        if (n >= N || 2L * m + 1 > N) {
            throw new IllegalArgumentException("graph too large: n=" + n + ", m=" + m);
        }
        // Bit length of n, i.e. floor(log2 n) + 1, without floating-point rounding.
        t = 32 - Integer.numberOfLeadingZeros(Math.max(n, 1));
        tot = 0;
        // Reset everything bfs() and add_edge() read, so the shared static
        // state can be reused for another graph.
        java.util.Arrays.fill(head, 0);
        java.util.Arrays.fill(nex, 0);
        java.util.Arrays.fill(d, 0);
        java.util.Arrays.fill(dist, 0);
        for (int[] row : f) {
            java.util.Arrays.fill(row, 0);
        }
        for (int i = 0; i < m; i++) {
            int x = cin.nextInt();
            int y = cin.nextInt();
            int z = cin.nextInt();
            add_edge(x, y, z);
            add_edge(y, x, z);
        }
    }

    /** Roots the tree at node 1 and fills the d, dist and f tables. */
    void bfs() {
        Queue<Integer> q = new LinkedList<Integer>();
        d[1] = 1;
        dist[1] = 0;
        q.offer(1);
        while (q.size() > 0) {
            int x = q.poll();
            for (int i = head[x]; i != 0; i = nex[i]) {
                int y = ver[i];
                if (d[y] != 0) continue;
                d[y] = d[x] + 1;
                dist[y] = dist[x] + edge[i];
                f[y][0] = x;
                // BFS order guarantees every ancestor's row is already complete.
                for (int j = 1; j <= t; j++)
                    f[y][j] = f[f[y][j - 1]][j - 1];
                q.offer(y);
            }
        }
    }

    /** Returns the depth of node x (root = 1). */
    int get_depth(int x) {
        return d[x];
    }

    /** Returns the lowest common ancestor of x and y; bfs() must have run. */
    int lca(int x, int y) {
        if (d[x] > d[y]) { int tmp = x; x = y; y = tmp; }
        // Lift the deeper node y to x's depth, largest jumps first.
        for (int i = t; i >= 0; i--) {
            if (d[f[y][i]] >= d[x]) y = f[y][i];
        }
        if (x == y) return x;
        // Lift both while their ancestors differ; they stop just below the LCA.
        for (int i = t; i >= 0; i--) {
            if (f[x][i] != f[y][i]) {
                x = f[x][i];
                y = f[y][i];
            }
        }
        return f[x][0];
    }

    /** Returns the weighted path length between x and y. */
    int get_distance(int x, int y) {
        return dist[x] + dist[y] - 2 * dist[lca(x, y)];
    }
}
public class Main {
    public static void main(String[] args) {
        // A single Scanner for all input: a second Scanner on System.in could
        // lose bytes already buffered by the first.
        Scanner cin = new Scanner(System.in);
        Graph graph = new Graph();
        graph.init_graph(cin);
        graph.bfs();
        int m = cin.nextInt();
        for (int k = 0; k < m; k++) {
            int x = cin.nextInt();
            int y = cin.nextInt();
            System.out.println(graph.get_distance(x, y));
        }
    }
}
View Code
原文地址:https://www.cnblogs.com/SwiftAC/p/12692567.html