1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
| #include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
// Wide integer type used for vertex weights and DP sums (long long, at least 64 bits).
using LL = long long;
// Large sentinel constant, 2^30.
// NOTE(review): not referenced anywhere in the code shown; kept in case other code uses it.
const int INF = 1 << 30;
// Reads the next integer from stdin.
// Every character before the first digit is skipped; if a '-' appears anywhere in
// that skipped run, the result is negated. Reading stops at the first non-digit
// after the number, and that character is consumed.
// Returns 0 if end of input is reached before any digit.
// No overflow checking is done: the value must fit in an int.
int getint() {
    int r = 0, k = 1;
    // getchar() returns int. Keeping it as int lets EOF be distinguished from data,
    // so the skip loop terminates at end of input instead of spinning forever.
    int c = getchar();
    for (; c != EOF && ('0' > c || c > '9'); c = getchar())
        if (c == '-') k = -1;
    for (; '0' <= c && c <= '9'; c = getchar())
        r = r * 10 + (c - '0');
    return r * k;
}
// One directed half of an undirected edge, stored in a linked adjacency list.
struct edge_type {
    int to;    // endpoint this half-edge points at
    int next;  // index of the next half-edge in the same list; 0 terminates the list
};
edge_type edge[2333333];

// Index of the last half-edge written. It starts at 1 and is pre-incremented, so
// real edges use indices 2, 3, 4, ... and index 0 never names an edge.
int cnte = 1;
// h[x] = index of the first half-edge in x's adjacency list (0 when the list is empty).
int h[2333333];
// Per-vertex visited flags.
bool vis[2333333];

// Adds the undirected edge {x, y} as two half-edges: first x->y (prepended to x's
// list), then y->x (prepended to y's list). When x == y, both land in x's list.
void ins(int x, int y) {
    const int from[2] = {x, y};
    const int dest[2] = {y, x};
    for (int d = 0; d < 2; ++d) {
        const int id = ++cnte;
        edge[id].to = dest[d];
        edge[id].next = h[from[d]];
        h[from[d]] = id;
    }
}
int h1, h2, fa[2333333];
void dfs(int now, int father) {
vis[now] = true; fa[now] = father;
for (int i = h[now]; i; i = edge[i].next) {
if (edge[i].to == father) continue;
if (vis[edge[i].to]) {
if (fa[edge[i].to] != now) {
h1 = edge[i].to;
h2 = now;
}
continue;
}
dfs(edge[i].to, now);
}
}
LL F[2333333][2], v[2333333];
void dfs1(int now, int father, int h1, int h2) {
F[now][1] = v[now]; F[now][0] = 0;
for (int i = h[now]; i; i = edge[i].next) {
if (edge[i].to == father) continue;
if ((now == h1 && edge[i].to == h2) || (now == h2 && edge[i].to == h1)) continue;
dfs1(edge[i].to, now, h1, h2);
F[now][0] += max(F[edge[i].to][0], F[edge[i].to][1]);
F[now][1] += F[edge[i].to][0];
}
}
bool v2[2333333];
// Same tree DP as dfs1, but with no removed edge. It walks every vertex reachable
// from `now` that is not yet marked in v2[], marks each one, and fills:
//   F[u][1] = v[u] + sum F[c][0]              (u selected)
//   F[u][0] = sum max(F[c][0], F[c][1])       (u not selected)
// The parent is never re-entered because it is already marked in v2[].
// Iterative post-order with explicit stacks avoids deep recursion. Visit and
// accumulation order match the recursive formulation.
void dfs2(int now) {
    // Parallel stacks: vertex, and its next adjacency index to examine.
    static std::vector<int> stk_node, stk_edge;
    stk_node.clear();
    stk_edge.clear();

    v2[now] = true;
    F[now][1] = v[now];
    F[now][0] = 0;
    stk_node.push_back(now);
    stk_edge.push_back(h[now]);

    while (!stk_node.empty()) {
        const int u = stk_node.back();
        const int i = stk_edge.back();
        if (i == 0) {
            // u's subtree is complete: pop it and fold it into its parent below.
            stk_node.pop_back();
            stk_edge.pop_back();
            if (!stk_node.empty()) {
                const int p = stk_node.back();
                F[p][0] += std::max(F[u][0], F[u][1]);
                F[p][1] += F[u][0];
            }
            continue;
        }
        stk_edge.back() = edge[i].next;  // advance u's cursor before descending
        const int to = edge[i].to;
        if (v2[to]) continue;
        v2[to] = true;
        F[to][1] = v[to];
        F[to][0] = 0;
        stk_node.push_back(to);
        stk_edge.push_back(h[to]);
    }
}
// Roots the DP at rt with the rt-rt1 edge cut, and returns the best value among
// selections that leave rt unselected (F[rt][0] as filled by dfs1).
LL dp1(int rt, int rt1) {
    dfs1(rt, -1, /*h1=*/rt, /*h2=*/rt1);
    const LL rootExcluded = F[rt][0];
    return rootExcluded;
}
// Runs the plain tree DP (dfs2) from rt and returns the better of the two root
// states: rt selected or rt not selected.
LL dp2(int rt) {
    dfs2(rt);
    const LL withoutRoot = F[rt][0];
    const LL withRoot = F[rt][1];
    return withRoot > withoutRoot ? withRoot : withoutRoot;
}
// Returns the best independent-set weight for the connected component containing rt.
// dfs() reports a non-tree edge (h1, h2) if it finds one. Its two endpoints are
// adjacent, so they cannot both be taken: the answer is the better of
// "h1 excluded" and "h2 excluded", each computed with that edge cut.
// If no such edge is reported, the component is handled as a tree by dp2.
LL solve(int rt) {
    // 0 means "no non-tree edge found". This assumes vertex ids start at 1, as in main.
    h1 = 0;
    h2 = 0;
    dfs(rt, -1);
    if (h1 == 0)
        return dp2(rt);
    const LL excludeH1 = dp1(h1, h2);
    const LL excludeH2 = dp1(h2, h1);
    return max(excludeH1, excludeH2);
}
int main() {
int n = getint(), x;
for (int i = 1; i <= n; ++i) {
v[i] = getint();
x = getint();
ins(x, i);
}
LL ans = 0;
for (int i = 1; i <= n; ++i)
if (!vis[i])
ans += solve(i);
printf("%lld", ans);
return 0;
}
|