This is from Topcoder SRM 605, Division 2, 1000-point problem.
Problem Statement.
#    N    K    Answer
-   --   --   -------
0    2    1         4   -- See below for an explanation
1    3    1         8
2    4    2        44
3   10   10    184756

Explanation of Example 0: There are six ways to partition the numbers { 1, 2, 3, 4 } into equal sized sets:
Let's start with a really brain-dead recursion. We will maintain two vectors in our class -- A and B. And we'll write a recursive dynamic programming function:
int DP(int next);
This will return the number of ways to put the numbers from next to N*2 into the two vectors subject to the rules. Since it is recursive, it will need to have a base case -- that will be when next > 2*N, which means that A and B have N elements each, in ascending order, and |A[i]-B[i]| ≤ K. In that case, return 1.
When we are not in the base case, what we do is test to see whether we can put next onto A. If we can, we do so, call DP(next+1) and add the result to our total. Then we remove the element from A, and do the same thing with B. You'll note, this is pretty much identical to how we solve Sudoku in CS202, and earlier in CS302. For each element, we simply try it one way, and then undo what we've done and try the other way.
We initially call DP() with next equal to one, and both A and B empty. Here's the code, in src/alien-1.cpp. This code has a main() which lets you set N and K from the command line.
#include <cstdio>
#include <vector>
using namespace std;

class AlienAndSetDiv2 {
  public:
    int N, K;
    vector <int> A;
    vector <int> B;
    int DP(int next);
    int getNumber(int n, int k);
};

/* DP is a very simple procedure, which is like the Sudoku problem from CS202,
   (and CS302).  It tests to see if it can insert the number onto A, and if so,
   it does so and makes a recursive call to DP(next+1) to count how many
   solutions start this way.  It does the same thing with B.  Its base case is
   when we are at 2N+1.  In that case, we've come up with an answer, and we'll
   return one. */

int AlienAndSetDiv2::DP(int next)
{
  int total;
  int i;

  /* Base case -- if we have inserted all numbers from 1 to 2n,
     then we're done, and can return 1. */

  if (next == 2*N+1) return 1;

  /* Otherwise, we're going to count solutions with recursion. */

  total = 0;

  /* First, we test to see if we can insert the number onto A.  If so, we do so,
     make a recursive call, and then remove the number from A. */

  if ((A.size() >= B.size() && A.size() < N) ||
      (A.size() < B.size() && next - B[A.size()] <= K)) {
    A.push_back(next);
    total += DP(next+1);
    A.pop_back();
  }

  /* We do the same thing with B. */

  if ((B.size() >= A.size() && B.size() < N) ||
      (B.size() < A.size() && next - A[B.size()] <= K)) {
    B.push_back(next);
    total += DP(next+1);
    B.pop_back();
  }

  /* When we're done, we print next, A and B, and the total. */

  printf("Next:%d A:{", next);
  for (i = 0; i < A.size(); i++) printf("%s%d", (i == 0) ? "" : ",", A[i]);
  printf("} B:{");
  for (i = 0; i < B.size(); i++) printf("%s%d", (i == 0) ? "" : ",", B[i]);
  printf("} -- %d\n", total);

  /* And return the total. */

  return total;
}

/* getNumber() sets N and K in the class.  It clears the vectors A and B (which
   will probably be empty anyway), and then calls DP(1) to start inserting
   numbers with 1. */

int AlienAndSetDiv2::getNumber(int n, int k)
{
  N = n;
  K = k;
  A.clear();
  B.clear();
  return DP(1);
}
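Take a look at those if statements. They have two parts: either the vector is at least as big as the other one and still has fewer than N elements, in which case next simply starts a new unmatched entry; or the vector is the smaller one, in which case next will be paired with an existing element of the other vector, and it has to be within K of that element.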
UNIX> bin/alien-1 2 1
Next:3 A:{1,2} B:{} -- 0
Next:4 A:{1,3} B:{2} -- 1
Next:4 A:{1} B:{2,3} -- 1
Next:3 A:{1} B:{2} -- 2
Next:2 A:{1} B:{} -- 2
Next:4 A:{2,3} B:{1} -- 1
Next:4 A:{2} B:{1,3} -- 1
Next:3 A:{2} B:{1} -- 2
Next:3 A:{} B:{1,2} -- 0
Next:2 A:{} B:{1} -- 2
Next:1 A:{} B:{} -- 4
4
UNIX> bin/alien-1 3 1 | tail -n 1
8
UNIX> bin/alien-1 4 2 | tail -n 1
44
UNIX> bin/alien-1 10 10 | tail -n 1
184756
UNIX> time sh -c "bin/alien-1 10 10 | tail -n 1"
184756
2.782u 0.035s 0:01.45 193.7% 0+0k 20+0io 0pf+0w
UNIX>

Well, the good thing is that we're getting correct answers. It's not a bad idea to trace through that first output to make sure that everything makes sense. The bad thing is that we're getting some exponential blow-up, and there's no way we're going to get the N=50, K=10 case to run in time.
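One way to convince yourself that these answers are right, independently of the recursion, is to enumerate every split directly. The following is a minimal sketch, not part of the lecture's source files: the function name bruteForceCount is made up for illustration, and it assumes the rule stated earlier (A and B each hold N values in ascending order, with |A[i]-B[i]| ≤ K). It tries all 2^(2N) subsets, so it is only practical for small N.

```cpp
#include <cstdlib>
#include <vector>

// Hypothetical helper (not from the lecture's code): counts the valid splits
// of {1..2N} by trying every subset as A and putting the rest into B.
long long bruteForceCount(int N, int K)
{
    long long count = 0;
    const int total = 2 * N;

    // Bit (v-1) of mask decides whether the value v goes to A or to B.
    for (unsigned mask = 0; mask < (1u << total); mask++) {
        std::vector<int> A, B;
        for (int v = 1; v <= total; v++) {
            if (mask & (1u << (v - 1))) A.push_back(v);
            else B.push_back(v);
        }
        if ((int) A.size() != N) continue;   // B then has N elements as well

        // A and B are already ascending, because v was scanned in increasing order.
        bool ok = true;
        for (int i = 0; i < N && ok; i++) {
            if (std::abs(A[i] - B[i]) > K) ok = false;
        }
        if (ok) count++;
    }
    return count;
}
```

Calling bruteForceCount(2, 1) walks through the six equal-sized splits of { 1, 2, 3, 4 } and rejects the two where A and B are { 1, 2 } and { 3, 4 } in either order, since |1-3| is bigger than K.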
int AlienAndSetDiv2::DP(int next)
{
  long long total;

  /* Base case -- if we have inserted all numbers from 1 to 2n,
     then we're done, and can return 1. */

  if (next == 2*N+1) return 1;

  /* Otherwise, we're going to count solutions with recursion. */

  total = 0;

  /* If the sets are the same size, just push next onto A, and
     multiply the answer by two. */

  if (A.size() == B.size()) {
    A.push_back(next);
    total += 2*DP(next+1);
    A.pop_back();
  }

  /* Otherwise, if there's room on A, try the value on A. */

  if (A.size() > B.size() && A.size() < N) {
    A.push_back(next);
    total += DP(next+1);
    A.pop_back();
  }

  /* We only push onto B if A is bigger, and if the value is legal. */

  if (B.size() < A.size() && next - A[B.size()] <= K) {
    B.push_back(next);
    total += DP(next+1);
    B.pop_back();
  }

  .....

  /* Return the total, mod 1,000,000,007. */

  return total % 1000000007;
}
When we run it, it's making far fewer calls, and example 3 (N = K = 10) is much faster:
UNIX> bin/alien-2 2 1
Next:3 A:{1,2} B:{} -- 0
Next:4 A:{1,3} B:{2} -- 1
Next:3 A:{1} B:{2} -- 2
Next:2 A:{1} B:{} -- 2
Next:1 A:{} B:{} -- 4
4
UNIX> bin/alien-2 3 1 | tail -n 1
8
UNIX> bin/alien-2 4 2 | tail -n 1
44
UNIX> bin/alien-2 10 10 | tail -n 1
184756
UNIX> time sh -c "bin/alien-2 10 10 | tail -n 1"
184756
0.348u 0.008s 0:00.19 178.9% 0+0k 16+0io 0pf+0w
UNIX>

Unfortunately, trying N = 50 and K = 10 doesn't complete, so we still have some work to do.
This observation is really nice, because it gives us a way to limit complexity, and to memoize. To be specific, we don't need to keep track of B any more, because there will never be any unmatched values on B. Let's instead turn A into a deque, and when we would previously call B.push_back() to put a value on B, we'll now call pop_front() on A. We'll have to store that value, so that we can put it back after making the recursive call.
Also, we can turn next and A into a string and memoize on the string.
We do both of these in src/alien-3.cpp:
/* We've removed B, changed A into a deque, and added a memoization Cache. */

class AlienAndSetDiv2 {
  public:
    int N, K;
    deque <int> A;
    int DP(int next);
    int getNumber(int n, int k);
    map < string, int > Cache;
};

int AlienAndSetDiv2::DP(int next)
{
  long long total;
  int i;
  int saved;
  string key;
  char buf[8];

  /* Base case -- if we have inserted all numbers from 1 to 2n,
     then we're done.  Return 1. */

  if (next == 2*N+1) return 1;

  /* Create a memoization key from next and A, and check the cache. */

  sprintf(buf, "%d", next);
  key = buf;
  for (i = 0; i < A.size(); i++) {
    sprintf(buf, "-%d", A[i]);
    key += buf;
  }
  if (Cache.find(key) != Cache.end()) return Cache[key];

  /* Otherwise, we're going to count solutions with recursion. */

  total = 0;

  /* If the sets are the same size, that means that A is empty.
     Push the value onto A and multiply the answer by two. */

  if (A.size() == 0) {
    A.push_back(next);
    total += 2*DP(next+1);
    A.pop_back();
  }

  /* Otherwise, if there's room on A, try the value on A.  We have to
     calculate "room on A" differently, since we are not keeping matches.
     We need to make sure that if we push the value on A, that there are
     enough values left on B. */

  if (A.size() > 0 && (N*2+1 - next) > A.size()) {
    A.push_back(next);
    total += DP(next+1);
    A.pop_back();
  }

  /* Now, instead of pushing the value on B, we'll remove the first value
     from A, and put it back when the recursion is done. */

  if (A.size() > 0 && next - A[0] <= K) {
    saved = A[0];
    A.pop_front();
    total += DP(next+1);
    A.push_front(saved);
  }

  /* Return the total, mod 1,000,000,007. */

  Cache[key] = total % 1000000007;
  return Cache[key];
}
I don't show the code, but I print the memoization cache after calling DP(1). Here are the examples:
UNIX> bin/alien-3 2 1
1 : 4
2-1 : 2
3 : 2
3-1-2 : 0
4-3 : 1
4
UNIX> bin/alien-3 3 1 | tail -n 1
8
UNIX> bin/alien-3 4 2 | tail -n 1
44
UNIX> bin/alien-3 10 10 | tail -n 1
184756
UNIX> time sh -c "bin/alien-3 10 10 | tail -n 1"
184756
0.018u 0.006s 0:00.02 50.0% 0+0k 7+0io 0pf+0w
UNIX> time sh -c "bin/alien-3 50 10 | tail -n 1"
153890414
4.876u 0.044s 0:04.70 104.4% 0+0k 0+0io 0pf+0w
UNIX>

Well, we're now fast enough that we can do the worst case of N equal to 50 and K equal to 10, and it completes. But it's still not fast enough.
So, let's fix that in src/alien-4.cpp. Here's the relevant code in DP():
  /* Otherwise, if there's room on A, and next is not too big, try the
     value on A.  We have to calculate "room on A" differently, since we
     are not keeping matches. */

  if (A.size() > 0 && next-A[0] < K && (N*2+1 - next) > A.size()) {
    A.push_back(next);
    total += DP(next+1);
    A.pop_back();
  }
When we try this, we see that it has eliminated some cases from the previous solution. In particular, the cache key "3-1-2" corresponds to next=3 and A = { 1, 2 }. alien-4 never reaches this state: with K = 1, once 2 is also on A, every later value is too far from 1 to be matched with it on B:
UNIX> bin/alien-4 2 1
1 : 4
2-1 : 2
3 : 2
4-3 : 1
4
UNIX> bin/alien-4 3 1 | tail -n 1
8
UNIX> bin/alien-4 4 2 | tail -n 1
44
UNIX> bin/alien-4 10 10 | tail -n 1
184756
UNIX> bin/alien-4 50 10 | tail -n 1
153890414
UNIX> time sh -c "bin/alien-4 50 10 | tail -n 1"
153890414
0.389u 0.008s 0:00.38 100.0% 0+0k 0+0io 0pf+0w
UNIX>

Even better yet, it solves the biggest case in 0.389 seconds -- we can submit to Topcoder!!!!
Instead of storing A as a set of values, store it as a set of differences from next. How does that work? Whenever you push a value onto A, you don't push next, but instead you push 0. Then, when you call DP(next+1), you increase every value in A by one.
That sounds inefficient, doesn't it? Well, it's not if you use bit arithmetic to store A. When you "push 0", you simply set A to (A|1). When you "increase every value in A by one," you set A to (A << 1). When you want to test to see whether you can legally push a value onto A, you can simply check to see if A is bigger than (1 << K).
The only operation that is kind of a pain is when you previously called A.pop_front(). That takes a little work. However, since K is capped at 10, A will never be bigger than 2048. You can even turn next and A into a single integer for memoization: key = (next << 12) | A.
How fun is that?
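To make the bit tricks concrete, here is one possible sketch of the idea. It is not the code from src/alien-5.cpp: the struct name AlienBits and the exact bit convention are assumptions made for illustration. In this sketch, bit d of mask being set means that the value next-d is sitting unmatched on A, so bit 0 is only set for an instant, right before the shift. It also assumes 1 ≤ K ≤ 10, which is what lets the mask fit into the low 12 bits of the key.

```cpp
#include <bitset>
#include <unordered_map>

// Sketch only: the names here are hypothetical, not taken from src/alien-5.cpp.
// Assumes 1 <= K <= 10, so mask always fits in 12 bits.
struct AlienBits {
    int N, K;
    std::unordered_map<int, int> cache;   // key = (next << 12) | mask

    // Bit d of mask set means the value (next - d) is unmatched on A.
    int DP(int next, int mask)
    {
        // All numbers placed: it only counts if nothing is left unmatched.
        if (next == 2 * N + 1) return (mask == 0) ? 1 : 0;

        int key = (next << 12) | mask;
        auto it = cache.find(key);
        if (it != cache.end()) return it->second;

        long long total = 0;
        int unmatched = (int) std::bitset<16>(mask).count();

        if (mask == 0) {
            // A is empty: "push 0" (set bit 0), age everything by one,
            // and double the result for the symmetric case.
            total += 2LL * DP(next + 1, (mask | 1) << 1);
        } else {
            // Push onto A only if the oldest difference is below K
            // and enough numbers remain to match everything later.
            if (mask < (1 << K) && (2 * N + 1 - next) > unmatched) {
                total += DP(next + 1, (mask | 1) << 1);
            }
            // "pop_front": the oldest value is the highest set bit.
            int top = 1;
            while ((top << 1) <= mask) top <<= 1;
            if (mask < (1 << (K + 1))) {   // oldest difference is at most K
                total += DP(next + 1, (mask ^ top) << 1);
            }
        }

        int answer = (int) (total % 1000000007);
        cache[key] = answer;
        return answer;
    }

    int getNumber(int n, int k)
    {
        N = n;
        K = k;
        cache.clear();
        return DP(1, 0);
    }
};
```

With this convention, the legality test for pushing onto A is a single comparison of mask against (1 << K), and undoing a move is free, because each recursive call gets its own copy of the integer instead of mutating a shared deque.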
The code is in src/alien-5.cpp, and after the initial call to DP() returns, we print the cache, extracting next and A from the memoization key. I don't include it here, but it is commented, and if you want to have some fun, give it a read. Then try to code it up yourself!
To time this vs. src/alien-4.cpp, I removed the print statements:
UNIX> g++ -O3 -o bin/alien-4 src/alien-4.cpp
UNIX> g++ -O3 -o bin/alien-5 src/alien-5.cpp
UNIX> time bin/alien-4 50 10
153890414
0.093u 0.001s 0:00.09 100.0% 0+0k 0+0io 0pf+0w
UNIX> time bin/alien-5 50 10
153890414
0.017u 0.001s 0:00.01 100.0% 0+0k 0+0io 0pf+0w
UNIX>

I have an even faster version in src/alien-6.cpp -- this one uses a two-dimensional vector for the cache. I don't have it commented yet, but I suspect no one will be reading it anyway. Someday when I'm bored, I'll comment it:
UNIX> time bin/alien-6 50 10
153890414
0.001u 0.000s 0:00.00 0.0% 0+0k 0+0io 0pf+0w
UNIX>
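For a feel of what a two-dimensional cache can look like, here is a hedged sketch. It is not src/alien-6.cpp, which isn't shown here, and it goes one step further than a cache by filling the table bottom-up instead of recursing. The function name countWithTable is hypothetical, the mask convention is an assumption (bit d set means the value next-d is unmatched on A), and it assumes 1 ≤ K ≤ 10. The "room on A" test is left out on purpose: a state that cannot finish simply ends up with a count of zero.

```cpp
#include <vector>

// Hypothetical bottom-up variant; not the author's src/alien-6.cpp.
// ways[next][mask] = number of ways to place next..2N, given the unmatched
// values encoded in mask (bit d set means the value next-d is unmatched).
// Assumes 1 <= K <= 10.
int countWithTable(int N, int K)
{
    const int MOD = 1000000007;
    const int width = 1 << (K + 1);   // masks only ever use bits 1..K
    std::vector<std::vector<int>> ways(2 * N + 2, std::vector<int>(width, 0));

    ways[2 * N + 1][0] = 1;           // everything placed and matched

    for (int next = 2 * N; next >= 1; next--) {
        for (int mask = 0; mask < width; mask += 2) {   // bit 0 is never set
            long long total = 0;
            if (mask == 0) {
                // Nothing unmatched: push next and double for symmetry.
                total = 2LL * ways[next + 1][2];
            } else {
                // Push next onto A if the oldest difference is below K.
                if (mask < (1 << K)) total += ways[next + 1][(mask | 1) << 1];
                // Match next with the oldest value (the highest set bit).
                int top = 1;
                while ((top << 1) <= mask) top <<= 1;
                total += ways[next + 1][(mask ^ top) << 1];
            }
            ways[next][mask] = (int) (total % MOD);
        }
    }
    return ways[1][0];
}
```

The table has at most 102 rows of 2048 entries, so indexing it directly avoids both the hashing and the recursion, at the cost of computing some states that are never reachable from the starting state.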
If you want a fun challenge that requires you to use bit arithmetic, try the Division 1, 500-point version of this problem!