首页 > 算法 > 使用线段树求区间和
2023
02-21

使用线段树求区间和

使用线段树求区间和 - 第1张  | Weiguang的博客
线段树(Segment Tree)是一种二叉搜索树,与区间树相似,他将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶子节点。使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为 O(logN)。而未优化的空间复杂度为 2N ,实际应用时一般还要开 4N 的数组以免越界,因此有时需要离散化让空间压缩。 —— 百度百科*

线段树可以使用一个堆来表示,存储在长度为 2N 的数组中(数组的下标从 1 开始,不使用 0 下标),对于一个非叶子节点 i ,其左儿子的索引为 2 * i,即 i << 1 ,其右儿子的索引为 2 * i + 1,即 i << 1 | 1。

一个线段树节点的定义如下,其中,l 和 r 分别表示节点对应区间的左右端点,v 则表示区间和。当 l 等于 r 时,该节点是一个叶子节点,v 存储的是原数组中元素的值。

  class Node {
    int l, r, v;

    public Node(int l, int r) {
      this.l = l;
      this.r = r;
    }

线段树可以用堆来表示,存储在数组中 :

Node[] st;

树状数组涉及的操作有两个,复杂度均为 O(log⁡n) :

  • void update(int node, int index, int val):含义为在以node为根节点的线段树的 index节点增加 val(注意位置下标从 1 开始);

  • int query(int node, int l, int r):含义为查询以node为根节点的线段树在 [l, r] 区间的和为多少(配合容斥原理,可实现任意区间查询)。

    更新操作 update

更新时,如果当前节点就是要更新的节点,则更新后直接返回。否则,判断需要更新的节点在当前节点的左子树还是右子树,对其更新后使用pushup (int node) 方法更新当前节点。

    private void update(int node, int index, int val) {
      if (st[node].l == index && st[node].r == index) {
        st[node].v += val;
        return;
      }
      int mid = (st[node].l + st[node].r) >> 1;
      if (index <= mid) update(node << 1, index, val);
      else update(node << 1 | 1, index, val);
      pushup(node);
    }

    private void pushup(int node) {
      st[node].v = st[node << 1].v + st[node << 1 | 1].v;
    }

构造线段树

保存线段树的数组下标从 1 开始,可以西安使用 buildTree() 方法初始化 Node[] 数组,然后调用 n 次 update() 方法完成初始化。

    Node[] st;
    
    public SegmentTree(int[] nums) {
      int n = nums.length;
      st = new Node[n * 4];
      buildTree(1, 1, n);
      for (int i = 0; i < n; i++) {
        update(1, i + 1, nums[i]);
      }
    }

    private void buildTree(int node, int l, int r) {
      st[node] = new Node(l, r);
      if (l == r) return;
      int mid = (l + r) >> 1;
      buildTree(node << 1, l, mid);
      buildTree(node << 1 | 1, mid + 1, r);
    }

查询区间和 query

若当前节点表示的区间和被查询范围 [l : r] 完全覆盖,则直接返回当前的区间和。否则,查询的区间有三种可能性:

  1. 全部在左子树上
  2. 全部在右子树上
  3. 左右子树各有一部分
    private int query(int node, int l, int r) {
      if (st[node].l >= l && st[node].r <= r) return st[node].v;
      int mid = (st[node].l + st[node].r) >> 1; // 其实不需要加括号,右移的优先级小于加法
      if (r <= mid) return query(node << 1, l, r);
      if (l > mid) return query(node << 1 | 1, l, r);
      return query(node << 1, l, mid) + query(node << 1 | 1, mid + 1, r);
    }

完整的线段树代码

class SegmentTree {

    Node[] st;
    
    public SegmentTree(int[] nums) {
      int n = nums.length;
      st = new Node[n * 4];
      buildTree(1, 1, n);
      for (int i = 0; i < n; i++) {
        update(1, i + 1, nums[i]);
      }
    }

    private void buildTree(int node, int l, int r) {
      st[node] = new Node(l, r);
      if (l == r) return;
      int mid = (l + r) >> 1;
      buildTree(node << 1, l, mid);
      buildTree(node << 1 | 1, mid + 1, r);
    }

    private void update(int node, int index, int val) {
      if (st[node].l == index && st[node].r == index) {
        st[node].v += val;
        return;
      }
      int mid = (st[node].l + st[node].r) >> 1;
      if (index <= mid) update(node << 1, index, val);
      else update(node << 1 | 1, index, val);
      pushup(node);
    }

    private void pushup(int node) {
      st[node].v = st[node << 1].v + st[node << 1 | 1].v;
    }

    private int query(int node, int l, int r) {
      if (st[node].l >= l && st[node].r <= r) return st[node].v;
      int mid = (st[node].l + st[node].r) >> 1; // 其实不需要加括号,右移的优先级小于加法
      if (r <= mid) return query(node << 1, l, r);
      if (l > mid) return query(node << 1 | 1, l, r);
      return query(node << 1, l, mid) + query(node << 1 | 1, mid + 1, r);
    }
  }

线段树在计算区间和上的应用

以LeetCode 307. 区域和检索 - 数组可修改为例。

class NumArray {

  SegmentTree st;
  int[] nums;
  
  public NumArray(int[] nums) {
    this.nums = nums;
    this.st = new SegmentTree(nums);
  }

  public void update(int index, int val) {
    int v = val - nums[index];
    nums[index] = val;
    st.update(1, index + 1, v);
  }

  public int sumRange(int left, int right) {
    return st.query(1, left + 1, right + 1);
  }

  private class Node {
    int l, r, v;

    public Node(int l, int r) {
      this.l = l;
      this.r = r;
    }
  }

  private class SegmentTree {
  
    Node[] st;
    
    public SegmentTree(int[] nums) {
      int n = nums.length;
      st = new Node[n * 4];
      buildTree(1, 1, n);
      for (int i = 0; i < n; i++) {
        update(1, i + 1, nums[i]);
      }
    }

    private void buildTree(int node, int l, int r) {
      st[node] = new Node(l, r);
      if (l == r) return;
      int mid = (l + r) >> 1;
      buildTree(node << 1, l, mid);
      buildTree(node << 1 | 1, mid + 1, r);
    }

    private void update(int node, int index, int val) {
      if (st[node].l == index && st[node].r == index) {
        st[node].v += val;
        return;
      }
      int mid = (st[node].l + st[node].r) >> 1;
      if (index <= mid) update(node << 1, index, val);
      else update(node << 1 | 1, index, val);
      pushup(node);
    }

    private void pushup(int node) {
      st[node].v = st[node << 1].v + st[node << 1 | 1].v;
    }

    private int query(int node, int l, int r) {
      if (st[node].l >= l && st[node].r <= r) return st[node].v;
      int mid = (st[node].l + st[node].r) >> 1; // 其实不需要加括号,右移的优先级小于加法
      if (r <= mid) return query(node << 1, l, r);
      if (l > mid) return query(node << 1 | 1, l, r);
      return query(node << 1, l, mid) + query(node << 1 | 1, mid + 1, r);
    }
  }
}

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray obj = new NumArray(nums);
 * obj.update(index,val);
 * int param_2 = obj.sumRange(left,right);
 */

参考资料:
【区间求和の解决方案】307. 区域和检索 - 数组可修改 :「树状数组」&「线段树」
线段树&树状数组

最后编辑:
作者:lwg0452
这个作者貌似有点懒,什么都没有留下。
捐 赠如果您觉得这篇文章有用处,请支持作者!鼓励作者写出更好更多的文章!

留下一个回复

你的email不会被公开。