动态规划解题思路

动态规划原理

动态问题通常是一个问题依赖于一个或多个子问题,基于子问题的解得出一个稍大点的子问题的解。

由于基于递归的算法常常是把原问题分为多个子问题,所以动态规划常常可以写成递归调用的形式。但有的动态规划问题,就不太容易写成递归形式。

斐波那契数列

这里使用知名的斐波那契数列问题来讲解动态规划,试图使用一个例子来揭示动态规划的原理。

求解斐波那契数列,可以以递归形式写出:

def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)

fib(n) 依赖于 fib(n-1)fib(n-2) 这两个子问题的。在不同的问题中,找出子问题始终是动态规划中最关键的一步,找到了子问题之后,就能写出递推关系,进而以递归形式实现一个算法。

但是递归的写法存在问题,观察下图,你会发现子问题,如 fib(3),被计算了很多次。

某个子问题可能被多次求解,整个算法的时间复杂度呈指数增长。为了避免重复计算子问题,引入优化策略,加入备忘录。

memo = {}
def fib(n):
    if n <= 1:
        return n
    if n in memo:
        return memo[n]
    ret = fib(n-1) + fib(n-2)
    memo[n] = ret;
    return ret

使用加入备忘录把子问题的解保存起来,如果发现某个子问题被求解过了,就直接返回之前计算出的结果。加入备忘录后可以大幅度地提高性能,常常效果已经很不错了。

基于递归的解法,自顶向下求解,要计算几个大问题,先递归地把子问题解决了,然后回过头来计算大问题。我们能不能先把子问题计算出来,然后在基于子问题求解一个稍大点的问题呢。

在斐波那契数列的例子中,即先计算出 fib(0), fib(1), 然后计算出 fib(2) = fib(0) + fib(1),进而计算出 fib(3) = fib(1) + fib(2),以此类推。如此,算法就是基于迭代的,而不是递归了。

这种自底向上的方法,就是动态规划。它先把子问题的结果计算出来,然后基于此计算出依赖于这些子问题的稍大点的子问题。和带备忘录的递归比起来,它的性能较好。因为,像斐波那契数列的例子中,只需要依赖于前两个子问题即可,而带备忘录的递归解法中会记录所有子问题的解,这浪费内存。而且常常迭代比递归稍微高效点。

把 fib 函数改为动态规划,写法如下:

def fib(n):
    if n <= 1:
        return n

    a, b = 0, 1

    for _ in range(1, n):
        a, b = b, a + b
    return b

这里 ab 记录了前两个子问题的解。由于始终只需要依赖前两个子问题,这里复用了 ab

动态规划解题套路

1. 首先分析问题,从中找出子问题,并写出递推关系,以 fib 为例:

f(n) = f(n-1) + f(n-2)
f(0) = 0
f(1) = 1

不同的问题,寻找递推关系的难易不同,通常找到递推关系后,问题就解决了大半。

2. 基于递推关系,实现递归版本的代码,可以写成伪代码,给后面转换为动态规划提供思路。

有前一步得到的递推关系,写出迭代版本应该不难。这一步有助于你深刻把握子问题间的依赖关系。

3. 给递归版本加上备忘录

动态规划常常需要使用额外的空间记录子问题的解,考虑如何添加备忘录,需要考虑子问题的输入与结果的映射关系。有的时候,输入时两个值,输出为一个数,此时就需要二维的数组来记录子问题的解。

4. 把递归版本转为动态规划

分析子问题的依赖关系,理清那些最最基本的问题,以及问题的依赖关系是怎样的。然后优先计算那个最基本的问题,然后依此继续计算其他问题的解。这一步常常比较困难,后面我会举几个例子来详细说明。

案例分析

最长回文子串

问题描述:给定一个字符串 s,找到 s 中最长的回文子串。你可以假设 s 的最大长度为 1000。

分析:

找最长的回文子串,暴力的解法是遍历每一个子字符串,然后判断它是不是回文串,如果是拿其长度与最大长度比较,恰当时更新最大长度。遍历完所有子字符串之后,最大长度就得到了。设字符串长度为 n,则存在 n^2 个子字符串,判断一个字符串是否为回文,时间复杂度为 O(n),因此总时间复杂度为 O(n^3)

稍微做点优化,如果一个字符是回文串,那么只要其左右两边的字符相等,那就可以把该回文串扩大两个字符。利用之前的计算结果,可以避免大量没必要的计算。要计算最长的回文子串,那就是遍历每个子字符串,并得到最大长度。因为判断是否为回文不必要遍历整个子字符串,所有此算法的时间复杂度为 O(n^2)

1. 找递推关系

假设有回文子串 S[left..right],如果 S 左右两边的字符相同,那么 S[left-1..right+1] 也是回文子串。单个字符串肯定是回文子串。因此,如果用 f(left, right) = true/false 表示子串 S[left..right] 是否为回文子串,那么可以得到如下递推关系:

f(left-1, right+1) = f(left, right) && S[left-1] == S[right+1] if right - left > 1

f(left, right) = s[left] == s[right] if right - left <= 1

2. 实现递归解法

实现一个 is_palindrome 函数,判断子字符串是否为回文串。

class Solution {
public:
    string longestPalindrome(string &s) {
        if (s.empty()) {
            return s;
        }
        const int N = s.size();
        int start = 0;
        int max_len = 1;
        for (int i = 0; i < N; i++) {
            for (int j = i + 1; j < N; j++) {
                if (is_palindrome(s, i, j) && max_len < (j - i + 1)) {
                    max_len = j - i + 1;
                    start = i;
                }
            }
        }
        return s.substr(start, max_len);
    }

    bool is_palindrome(const string &s, int left, int right) {
        if (right - left <= 1) {
            return s[left] == s[right];
        }
        if (s[left] == s[right] && is_palindrome(s, left + 1, right - 1)) {
            return true;
        }
        return false;
    }
};

根据前面的分析,递归写法会重复计算子问题,我们需要给它增加一个备忘录。

3. 增加备忘录

使用一个二维数组记录下 s[left..right] 是否为回文串。在未计算时 memo 中均为 -1,计算完成后使用 1 表示是回文,用 0 表示不是回文。

class Solution {
public:
    string longestPalindrome(string &s) {
        if(s.empty()) return s;
        const int N = s.size();
        int start = 0;
        int max_len = 1;
        vector<vector<char>> memo(N, vector<char>(N, -1));

        for(int i=0;i<s.size();i++){
            for(int j=i+1; j<s.size();j++){
                if(is_palindrome(s, i, j, memo) && max_len < (j - i + 1)){
                    max_len = j - i + 1;
                    start = i;
                }
            }
        }
        return s.substr(start, max_len);
    }

    bool is_palindrome(const string& s, int left, int right, vector<vector<char>>& memo){
        if(memo[left][right] != -1){
            return memo[left][right] == 1;
        }

        if(right - left <= 1 && s[left] == s[right]){
            memo[left][right] = 1;
            return true;
        }

        if(s[left] == s[right] && is_palindrome(s, left+1, right-1, memo)){
            memo[left][right] = 1;
            return true;
        }
        memo[left][right] = 0;
        return false;
    }
};

4. 改写成动态规划

f(left, right) 依赖 f(left+1, right-1),很明显,长字符串依赖于短字符串。因此,要想自底向上地计算,可以先判断短字符串是否为回文串,然后把结果存起来。而后再求更大点的子问题。

因为子问题之间存在依赖,确定正确的计算顺序很重要。正确的计算顺序才能保证计算某个子问题时,它所依赖的子问题已经被求解出来了。

对于二维动态规划,通常可以绘制表格,通过表格来观察依赖关系,同时确定计算顺序。

确定横纵坐标后,不难发现,副对角线上的值全为 true, 副对角线上方的是无效区域,因为 left 始终小于 right。而且除了 right-left <= 1 的点(字符相邻)之外,其他点都依赖于右上角的点。

目前,需要基于副对角线上的 true 和字符串 S 的内容,来填写副对角线以下的区域。对于二维矩阵,填表有两种方式:

  1. 一行一行地填
  2. 一列一列地填

有时候需要逆序填写,有时候可以顺序填,这都需要具体问题具体分析。

此处,为了保证在填某个格子时,其右上角已经填写完成,此处只能选择一行一行地填写。因此可以写出如下循环:

for(int right = 0; right < N; right ++){
    for(int left = 0; left < right; left ++){
        // 填表
    }
}

二维的动态规划,都要写一个两层的循环,这个循环怎么写,就是基于表格分析出来的。拿到一个新问题,先画出表格,在表格中确定依赖关系,然后根据依赖关系确定计算顺序,明确了计算顺序,这个二层的循环也就好写了。

下面是动态规划的写法:

class Solution {
public:
    string longestPalindrome(string &s) {
        if(s.empty()){
            return s;
        }
        const int N = s.size();
        int start = 0;
        int max_len = 1;
        vector<vector<bool>> dp(N, vector<bool>(N, false));

        for(int right = 0; right < N; right++){
            dp[right][right] = true;
            for(int left = 0; left < right; left++){
                dp[left][right] = (s[left] == s[right]) &&
                    (right - left <= 1 || dp[left+1][right-1]);

                if(dp[left][right] && (right - left + 1) > max_len){
                    max_len = right - left + 1;
                    start = left;
                }                
            }
        }
        return s.substr(start, max_len);
    }
};

优化

至此,本题算是可以了,不过还是可以继续优化。观察前面的表格中的依赖关系,我们发现可以逐行填充 dp。另外可以轻易看出,在填充某行的时候,只依赖于上面的一行,之前的很多行都没用了。因此,可以考虑压缩 dp 矩阵,使用一维的数组来存储状态。

利用上一层使用过的数组,用第二个元素(蓝色)计算出第一个(棕色),蓝色元素为上一层的信息,棕色元素为本计算结果。如此就可以把 dp 变为一维的。这能大幅节省内存。

class Solution {
public:
    string longestPalindrome(string &s) {
        if(s.empty()) return s;
        const int N = s.size();
        int start = 0;
        int max_len = 1;
        vector<bool> dp(N, false);

        for(int right = 0; right < N; right++){
            dp[right] = true;
            for(int left = 0; left < right; left++){
                dp[left] = (s[left] == s[right]) &&
                    (right - left <= 1 || dp[left+1]);

                if(dp[left] && (right - left + 1) > max_len){
                    max_len = right - left + 1;
                    start = left;
                }                
            }
        }

        return s.substr(start, max_len);
    }
};

这种压缩 dp 矩阵的场景还是挺多的。只要发现某一行/列只依赖于前一行/列,且在覆盖旧值之前能够把新值计算出来,这样的情况就可以考虑压缩。

打家劫舍

题目链接:https://leetcode-cn.com/problems/house-robber

题目描述:

你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间相邻的房屋在同一晚上被小偷闯入,系统会自动报警。

给定一个代表每个房屋存放金额的非负整数数组,计算你在不触动警报装置的情况下,能够偷窃到的最高金额。

思路:

设小偷在经过第 n 个房屋后偷到的现金总量为 f(n),根据题意可以得出如下递推公式:

f(n) = max(nums[n] + f(n-2), f(n-1))

f(0) = nums[0]

f(1) = max(nums[0], nums[1])

递归写法

基于以上递推关系,不难得出递归解法如下:

class Solution {
public:
    int rob(vector<int>& nums, int n) {
        if(n == 0){
            return nums[0];
        }
        if(n == 1){
            return max(nums[0], nums[1]);
        }
        return max(nums[n] + rob(nums, n-2), rob(nums, n-1));
    }

    int rob(vector<int>& nums) {
        return rob(nums, nums.size()-1);
    }
};

同理,可以添加备忘录来加速,这里就省略了,我们主要来看看动态规划的写法。

动态规划写法

这里只有一个变量 n,因此是一个一维的动态规划。

class Solution {
public:
    int rob(vector<int>& nums) {
        if(nums.size() == 0){
            return 0;
        }
        if(nums.size() == 1){
            return nums[0];
        }
        vector<int> dp(nums.size(), 0);
        dp[0] = nums[0];
        dp[1] = max(nums[0], nums[1]);

        for(int i=2;i<nums.size();i++){
            dp[i] = max(nums[i] + dp[i-2], dp[i-1]);
        }

        return dp.back();
    }
};

由于 f(n) 只依赖于 f(n-1)f(n-2),因此,其实不需要一个一维的 dp 数组,只需要两个变量即可。使用两个变量记录下 n-1 和 n-2 的状态即可。

class Solution {
public:
    int rob(vector<int>& nums) {
        int dp[2] = {0, 0};
        for(int num: nums){
            int n = max(num + dp[0], dp[1]);
            dp[0] = dp[1];
            dp[1] = n; 
        }
        return dp[1];
    }
};