BZOJ 1468: Tree

Description

给你一棵树(TREE)以及树上每条边的长度,问有多少对点之间的距离小于等于 K。

Input

第一行为节点数 N(N≤40000)。接下来 N−1 行,每行三个整数 u v w,表示一条连接 u、v 且长度为 w 的边。最后一行为 K。

Output

一行一个整数,表示距离小于等于 K 的点对数。

Sample Input

7

1 6 13

6 3 9

3 5 7

4 1 3

2 4 20

4 7 2

10

Sample Output

5

HINT

Source

LTC男人八题系列

Solution

点分治。

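设当前分治到的树为 $T_{rt}$,其节点数为 $N$,重心(作为根)为 $rt$,$dis[i]$ 表示以 $rt$ 为根时 $rt$ 到 $i$ 的距离,$T_v$ 表示 $rt$ 的儿子 $v$ 所在的子树。此时答案为:$$Ans(T_{rt})=\sum_{v\in son(rt)}Ans(T_v)+\sum_{i\in T_{rt}}\sum_{j\in T_{rt}}[dis[i]+dis[j]\le K]-\sum_{v\in son(rt)}\sum_{i\in T_v}\sum_{j\in T_v}[dis[i]+dis[j]\le K]$$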

最后一项减去的是一种不合法的情况:$i$ 和 $j$ 属于 $rt$ 的同一棵子树且 $dis[i]+dis[j]\le K$。这样的点对之间的路径并不经过 $rt$,$dis[i]+dis[j]$ 也不是它们的真实距离;它们会在递归处理该子树时被正确统计。

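注意点分治的实现:每次先找出当前连通块的重心作为根,统计后断开重心与各儿子之间的边,再对每棵子树递归。另外,上式统计的是有序点对,并且包含 $i=j$ 的情形(每个点恰好在作为重心时被计入一次),因此最终答案为 $(Ans-N)/2$。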

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <cstdio>
#include <algorithm>
using namespace std;
// 64-bit integer alias for quantities that may not fit in int.
using LL = long long;
// Sentinel larger than any component size; seeds the centroid search.
constexpr int INF = 1009999999;
// Capacity of every node-indexed array (node ids are 1-based).
constexpr int maxn = 50005;
/*
 * Read the next integer from stdin.
 * Any non-digit characters before the number are skipped; a '-' seen among
 * them makes the result negative. Returns 0 if EOF is reached before a digit,
 * so truncated input cannot make the reader loop forever.
 */
int getint() {
	int c = getchar();  // int, not char: EOF must stay distinguishable
	int sign = 1;
	while (c != EOF && (c < '0' || c > '9')) {
		if (c == '-') sign = -1;
		c = getchar();
	}
	int r = 0;
	for (; c >= '0' && c <= '9'; c = getchar())
		r = r * 10 + (c - '0');
	return r * sign;
}
// Adjacency lists kept as singly linked chains of directed arcs.
struct edge_type {
	int to;     // head of the arc
	int next;   // previously inserted arc leaving the same node (0 ends the chain)
	int w;      // arc length
	int baned;  // set to 1 once the arc has been cut
} edge[maxn << 1];
// cnte starts at 1, so the first arc gets index 2; arcs inserted back to back
// therefore share an index pair 2k / 2k+1, reachable from each other with ^1.
int cnte = 1;
int h[maxn];  // h[x]: most recently inserted arc leaving x (0 = none)

// Append the directed arc x -> y of length z to x's chain.
void ins(int x, int y, int z) {
	edge_type &arc = edge[++cnte];
	arc.to = y;
	arc.w = z;
	arc.baned = 0;
	arc.next = h[x];
	h[x] = cnte;
}
int N, K, siz[maxn];

// Read the node count, N-1 weighted edges (each stored as two opposite arcs),
// and finally the distance bound K.
void init() {
	N = getint();
	for (int remaining = N - 1; remaining > 0; --remaining) {
		// Separate statements keep the read order u, v, w well defined.
		const int u = getint();
		const int v = getint();
		const int len = getint();
		ins(u, v, len);
		ins(v, u, len);
	}
	K = getint();
}
int Mag, HSK;  // best (smallest) largest-part size seen so far, and the node achieving it

// Fill siz[] for the subtree of `now` (not crossing cut arcs or returning to
// `father`) and keep the node whose largest remaining part is smallest.
// `size` is the caller-supplied size of the whole component.
int dfs(int now, int father, int size) {
	siz[now] = 1;
	int heaviest = 0;
	for (int e = h[now]; e != 0; e = edge[e].next) {
		const int v = edge[e].to;
		if (edge[e].baned || v == father) continue;
		siz[now] += dfs(v, now, size);
		if (siz[v] > heaviest) heaviest = siz[v];
	}
	const int upward = size - siz[now];  // part lying above `now`
	if (upward > heaviest) heaviest = upward;
	if (heaviest < Mag) {
		Mag = heaviest;
		HSK = now;
	}
	return siz[now];
}
// Return the node chosen as centroid of the component that contains `now`.
int find_root(int now, int size) {
	Mag = INF;                 // no candidate recorded yet
	dfs(now, -1, size);        // updates siz[], Mag and HSK
	const int centroid = HSK;
	return centroid;
}
int dis[maxn], dst[maxn], ptr;
void get_dis(int now, int father, int ndis) {
	dis[now] = ndis;
	for (int i = h[now]; i; i = edge[i].next)
		if (edge[i].baned == 0 && edge[i].to != father)
			get_dis(edge[i].to, now, ndis + edge[i].w);
}
// Append dis[] of every node reachable from `now` (not via cut arcs or
// `father`) to dst[1..], advancing ptr once per visited node.
void DFS(int now, int father) {
	++ptr;
	dst[ptr] = dis[now];
	for (int e = h[now]; e != 0; e = edge[e].next) {
		const int nxt = edge[e].to;
		if (edge[e].baned == 0 && nxt != father)
			DFS(nxt, now);
	}
}
int calc(int rt) {
	int ret = 0; ptr = 0;
	DFS(rt, -1);
	sort(dst+1, dst+ptr+1);
	for (int i = 1, j = ptr; dst[i] <= K && i <= ptr; ++i) {
		for (; dst[j] + dst[i] > K && j; --j);
		ret += j;
	}
	return ret;
}
int Ans = 0;
void Solve(int now, int size) {
	int rt = find_root(now, size);
	get_dis(rt, -1, 0);
	Ans += calc(rt);
	for (int i = h[rt]; i; i = edge[i].next) {
		if (edge[i].baned == 0) {
			edge[i^1].baned = edge[i].baned = 1;
			Ans -= calc(edge[i].to);
			Solve(edge[i].to, siz[edge[i].to]);
		}
	}
	return;
}
int main() {
	init();
	Solve(1, N);
	printf("%d", (Ans - N) >> 1);
	return 0;
}