CODEVS 4203 山区建小学

Description

政府在某山区修建了一条道路,恰好穿越总共m个村庄的每个村庄一次,没有回路或交叉,任意两个村庄只能通过这条路来往。已知任意两个相邻的村庄之间的距离为di(为正整数),其中,0 < i < m。为了提高山区的文化素质,政府又决定从m个村中选择n个村建小学(设 0 < n < = m < 500 )。请根据给定的m、n以及所有相邻村庄的距离,选择在哪些村庄建小学,才使得所有村到最近小学的距离总和最小,计算最小值。

Input

第1行为m和n,其间用空格间隔

第2行为(m-1) 个整数,依次表示从一端到另一端的相邻村庄的距离,整数之间以空格间隔。

例如

10 3

2 4 6 5 2 4 3 1 3

表示在10个村庄建3所学校。第1个村庄与第2个村庄距离为2,第2个村庄与第3个村庄距离为4,第3个村庄与第4个村庄距离为6,…,第9个村庄到第10个村庄的距离为3。

Output

各村庄到最近学校的距离之和的最小值。

Sample Input

10 2

3 1 3 1 1 1 1 1 3

Sample Output

18

Solution

先说些题外话,这是今天SCX大犇出的NOIP模拟题T2(难度还可以的)。

一开始我想了个很显然的DP:

\[DP_{i,j}=DP_{k,j-1}+Q_{k+1,i}\]

\(Q_{L,R}\)表示的是从L到R村庄如果建一所学校,最小距离和是多少。

然后我神奇的发现这是一个\(O(n^4)\)的DP!

天啦噜!考虑优化!

我们可以用前缀和来维护由第k个村庄到L…R村的距离,然后枚举L、R和中间断点MID,来求出\(Q_{i,j}\),然后跑DP就可以了。

显然这是一个\(O(n^3)\)的DP,对于本题来说能过了。

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <cstdio>
#include <map>
#include <algorithm>
using namespace std;

typedef long long ll;
int getint() {
	int r = 0, k = 1; char c = getchar();
	for (; c < '0' || c > '9'; c = getchar()) if (c == '-') k = -1;
	for (; '0' <= c && c <= '9'; c = getchar()) r = r * 10 - '0' + c;
	return r*k;
}
const int INF = 1000000007;
const int maxn = 505;
int n, m, a[maxn], p[maxn], que[maxn][maxn], mque[maxn][maxn];
int dp[maxn][maxn];
int dis(int i, int j) { if (j > i) return p[j] - p[i]; return p[i]-p[j];}
int main() {
	n = getint(); m = getint();
	for (int i = 1; i < n; ++i) a[i] = getint(), p[i+1] = p[i]+a[i];
	for (int i = 0; i <= n; ++i)
		for (int j = 0; j <= n; ++j)
			mque[i][j] = dp[i][j] = INF;
	for (int i = 1; i <= n; ++i)
		for (int j = 1; j <= n; ++j)
			que[i][j] = que[i][j-1]+dis(i, j);
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= i; ++j)
			for (int k = j; k <= i; ++k)
				mque[j][i] = min(mque[j][i], que[k][i] - que[k][j-1]);
		dp[i][i] = 0;
		dp[i][1] = mque[1][i];
	}
	for (int i = 1; i <= n; ++i)
		for (int j = 2; j <= m; ++j)
				for (int k = 1; k < i; ++k)
					dp[i][j] = min(dp[i][j], dp[k][j-1]+mque[k+1][i]);
	printf("%d", dp[n][m]);
	return 0;
}