Skip to content

Latest commit

 

History

History
390 lines (293 loc) · 9.77 KB

线段树.md

File metadata and controls

390 lines (293 loc) · 9.77 KB

线段树总结

区间和检索-数组可修改

LeetCode中文

LeetCode英文

给定一个整数数组 nums,求出数组从索引 ij (i ≤ j) 范围内元素的总和,包含 i, j 两点。

update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。

示例:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8

说明:

  1. 数组仅可以在 update 函数下进行修改。
  2. 你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。

解答

参考 线段树课程-九章算法

  • 时间复杂度:build为O(n),updategetSum为O(log n)
  • 空间复杂度:O(n)
class segTree
{
public:
    segTree(int l,int r)
    {
        start = l;
        end = r;
        sum = 0;
        left = right = nullptr;
    }
    
    int start,end,sum;
    segTree* left;
    segTree* right;
};

void update_node(segTree*& root,int i,int val)
{
    if(root->start == root->end && root->start == i)
    {
        root->sum = val;
        return;
    }
    
    int m = root->start + (root->end - root->start)/2;
    if(i <= m && i >= root->start)
    {
        update_node(root->left,i,val);
    }
    else if(i >= m+1 && i <= root->end)
    {
        update_node(root->right,i,val);
    }
    
    root->sum = root->left->sum + root->right->sum;
}

segTree* build(int l,int r,vector<int>& vec)
{
    if(l == r)
    {
        segTree* tmp = new segTree(l,r);
        tmp->sum = vec[l];
        return tmp;
    }
    
    segTree* res = new segTree(l,r);
    int m = l + (r - l) / 2;
    res->left = build(l,m,vec);
    res->right = build(m + 1,r,vec);
    
    res->sum = res->left->sum + res->right->sum;
    return res;
}

int getSum(int l,int r,segTree* root)
{
    if(l == root->start && r == root->end)
        return root->sum;
    
    int m = root->start + (root->end - root->start) / 2;
    int sum_l = 0;
    int sum_r = 0;
    if(l <= m)
    {
        if(r<=m)
            sum_l = getSum(l,r,root->left);
        else
            sum_l = getSum(l,m,root->left);
    }
    
    if(r > m)
    {
        if(l > m)
            sum_r = getSum(l,r,root->right);
        else
            sum_r = getSum(m + 1,r,root->right);
    }
        
    return (sum_l + sum_r);
}

class NumArray {
public:
    NumArray(vector<int> nums) {
        if(nums.empty())
            return;
        
        vec = nums;
        root = build(0,nums.size() - 1,nums);
    }
    
    void update(int i, int val) {
        update_node(root,i,val);
    }
    
    int sumRange(int i, int j) {
        return getSum(i,j,root);
    }
    
private:
    segTree* root;
    vector<int> vec;
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray obj = new NumArray(nums);
 * obj.update(i,val);
 * int param_2 = obj.sumRange(i,j);
 */

计算右侧小于当前元素的个数

LeetCode中文

LeetCode英文

给定一个整数数组 nums,按要求返回一个新数组 counts

数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

示例:

输入: [5,2,6,1]
输出: [2,1,1,0] 
解释:
5 的右侧有 2 个更小的元素 (2 和 1).
2 的右侧仅有 1 个更小的元素 (1).
6 的右侧有 1 个更小的元素 (1).
1 的右侧有 0 个更小的元素.

解答

此问题可以转化为线段树的区间问题,对于nums[i],求小于它的元素个数相当于是求在区间0~nums[i]-1中的元素个数。开辟数组count,找出数组nums的最大值num_max和最小值num_mincount的大小为num_max-num_min+1count[i]表示数i+num_min到目前为止的数量,将count建立成线段树。然后从右往左遍历nums,遇到nums[i]时,去查询count0~nums[i]-1区间的总和,即nums[i]右边小于它的元素个数。每次查询完后,将查询结果添加到结果res中,然后将nums[i]count对应的计数值加一(即count[nums[i] - num_min]加一),同时更新到线段树。遍历结束后,将结果数组res反转即为所求。

具体举例nums = [5,2,6,1],则num_min为1,num_max为6,count大小为6,count = [0,0,0,0,0,0],我们基于count建立线段树。然后从右往左遍历nums

  1. nums[3] = 1,查询区间[0,0]的和,得到0,加入res,然后在对应下标nums[i]-num_min即在0的位置加1。此时count[1,0,0,0,0,0]res[0]

  2. nums[2] = 6,查询区间[0,5]的和,得到1,加入res,然后在对应下标nums[i]-num_min即在5的位置加1。此时count[1,0,0,0,0,1]res[0,1]

  3. nums[1] = 2,查询区间[0,1]的和,得到1,加入res,然后在对应下标nums[i]-num_min即在1的位置加1。此时count[1,1,0,0,0,1]res[0,1,1]

  4. nums[0] = 5,查询区间[0,4]的和,得到2,加入res,然后在对应下标nums[i]-num_min即在1的位置加1。此时count[1,1,0,0,1,1]res[0,1,1,2]

最后,将res反转得到[2,1,1,0]即为结果。

  • 时间复杂度:O(n log n)
  • 空间复杂度:O(n)
class segNode
{
public:
    segNode(int l,int r)
    {
        start = l;
        end = r;
        sum = 0;
        left = right = nullptr;
    }

    segNode* left,*right;
    int start,end,sum;
};


segNode* build(int l,int r,vector<int>& vec)
{
    if(l == r)
    {
        segNode* tmp = new segNode(l,r);
        tmp->sum = vec[l];
        return tmp;
    }
    
    int m = l + (r - l)/2;
    segNode* res = new segNode(l,r);
    res->left = build(l,m,vec);
    res->right = build(m+1,r,vec);
    res->sum = res->left->sum + res->right->sum;
    
    return res;
}

int query(int l,int r,segNode* root)
{
    if(l > r)
        return 0;
    
    if(l == root->start && r == root->end)
        return root->sum;
    
    int m = root->start + (root->end - root->start)/2;
    int sum_l = 0;
    int sum_r = 0;
    if(l <= m)
    {
        if(r<=m)
            sum_l = query(l,r,root->left);
        else
            sum_l = query(l,m,root->left);
    }
    
    if(r > m)
    {
        if(l > m)
            sum_r = query(l,r,root->right);
        else
            sum_r = query(m+1,r,root->right);
    }
    
    return (sum_l + sum_r);
}

void updateNode(int i,int val,segNode* root)
{
    if(i == root->start && i == root->end)
    {
        root->sum = val;
        return;
    }
    
    int m = root->start + (root->end - root->start)/2;
    if(i <= m)
    {
        updateNode(i,val,root->left);
    }
    else{
        updateNode(i,val,root->right);
    }
    
    root->sum = root->left->sum + root->right->sum;
    return;
}

class Solution {
public:
    vector<int> countSmaller(vector<int>& nums) {
        if(nums.empty())
            return vector<int>();
        
        int num_max = INT_MIN;
        int num_min = INT_MAX;
        vector<int> res;
        for(auto a : nums)
        {
            num_max = max(a,num_max);
            num_min = min(a,num_min);
        }
        
        vector<int> count(num_max - num_min + 1,0);
        root = build(0,count.size()-1,count);
        
        for(int i=nums.size()-1;i>=0;i--)
        {
            int index = nums[i] - num_min;
            res.push_back(query(0,index-1,root));
            
            
            int old_val = query(index,index,root);
            updateNode(index,old_val+1,root);
        }
        
        reverse(res.begin(),res.end());
        
        return res;
    }
 
private:
    segNode* root;
};

字符串中的加粗单词

LeetCode中文

LeetCode英文

给定一个关键词集合 words 和一个字符串 S,将所有 S 中出现的关键词加粗。所有在标签 <b></b> 中的字母都会加粗。

返回的字符串需要使用尽可能少的标签,当然标签应形成有效的组合。

例如,给定words = ["ab", "bc"]S = "aabcd",需要返回 "a<b>abc</b>d"。注意返回 "a<b>a<b>b</b>c</b>d" 会使用更多的标签,因此是错误的。

  • words 长度的范围为 [0, 50]
  • words[i] 长度的范围为 [1, 10]
  • S 长度的范围为 [0, 500]
  • 所有 words[i]S 中的字符都为小写字母。

解答

方法1:字典树

C++代码

方法2:哈希

Python代码

class Solution:
    def boldWords(self, words: List[str], S: str) -> str:
        bold_tag = [False] * len(S)
        for w in words:
            b = 0
            l = len(w)
            while b < len(S):
                if w not in S[b:]:
                    break
                idx = S.index(w, b)
                for i in range(idx, idx + l):
                    bold_tag[i] = True
                b = idx + 1
        res = ''
        i = 0
        while i < len(S):
            if not bold_tag[i]:
                res += S[i]
                i += 1
            else:
                res += '<b>' + S[i]
                i += 1
                while i < len(S) and bold_tag[i]:
                    res += S[i]
                    i += 1
                res += '</b>'
        return res