
Segment Tree

Answer range queries in \(O(\log n)\) with \(O(\log n)\) point updates
build: \(O(n)\) · query: \(O(\log n)\) · update: \(O(\log n)\) · space: \(O(n)\)

A segment tree is a tree data structure built on top of an array that lets you answer range queries (like "what is the sum of elements from index 3 to 8?") and perform point updates in \(O(\log n)\) time each.

Without a segment tree, a range sum query takes \(O(n)\) time (loop through all elements). A prefix sum array gives \(O(1)\) queries but \(O(n)\) updates, because changing one element shifts every later prefix. A segment tree takes \(O(n)\) time to build and then answers any range query and performs any point update in \(O(\log n)\), a big win when many queries and updates are interleaved.
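As a point of comparison, here is a minimal prefix-sum sketch (the class name `PrefixSums` and its methods are illustrative, not from a library). A query is one subtraction, but changing a single element has to shift every prefix after it, which is where the \(O(n)\) update cost comes from.

```python
class PrefixSums:
    """Prefix-sum array: O(1) range-sum queries, O(n) point updates."""

    def __init__(self, arr):
        # pre[i] holds arr[0] + ... + arr[i-1]; pre[0] = 0
        self.pre = [0]
        for x in arr:
            self.pre.append(self.pre[-1] + x)

    def range_sum(self, l, r):
        # Sum of arr[l..r] inclusive: a single subtraction, O(1)
        return self.pre[r + 1] - self.pre[l]

    def update(self, i, val):
        # Every prefix that includes arr[i] shifts by the same delta: O(n)
        delta = val - (self.pre[i + 1] - self.pre[i])
        for j in range(i + 1, len(self.pre)):
            self.pre[j] += delta


ps = PrefixSums([1, 3, 5, 7, 9, 11])
print(ps.range_sum(1, 3))  # 15
ps.update(2, 10)           # arr[2]: 5 -> 10
print(ps.range_sum(1, 3))  # 20
```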

The core idea: divide and conquer

A segment tree stores aggregated values for ranges. The root stores the aggregate for the entire array. Every other node covers one half of its parent's range and stores the aggregate for that half, and the leaves store the individual array elements.

For an array of size 8: [1, 3, 5, 7, 9, 11, 13, 15]

```
                      [0..7] = 64
           [0..3] = 16              [4..7] = 48
     [0..1] = 4   [2..3] = 12   [4..5] = 20   [6..7] = 28
      1     3      5     7       9    11      13    15
```
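The numbers in the diagram can be reproduced by repeatedly summing adjacent pairs. The helper `levels` below is a hypothetical illustration, and it assumes the array length is a power of two so that every level pairs up evenly; it is not how a general segment tree is built.

```python
def levels(arr):
    """Return the tree's levels from root to leaves (assumes len(arr) is a power of two)."""
    rows = [list(arr)]
    while len(rows[0]) > 1:
        below = rows[0]
        # Each parent is the sum of two adjacent children
        rows.insert(0, [below[i] + below[i + 1] for i in range(0, len(below), 2)])
    return rows


for row in levels([1, 3, 5, 7, 9, 11, 13, 15]):
    print(row)
# [64]
# [16, 48]
# [4, 12, 20, 28]
# [1, 3, 5, 7, 9, 11, 13, 15]
```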

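To query sum(2, 6), start at the root and combine only the segments that lie entirely inside the range: [2..3] = 12, [4..5] = 20 and the leaf [6..6] = 13, for a total of 45. At each level at most two nodes partially overlap the range and have to be split further, so at most four nodes per level are visited, which gives \(O(\log n)\).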
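The sketch below makes the decomposition visible. It stores each segment's sum in a dictionary keyed by `(lo, hi)` (a layout chosen for readability, not the array layout used later) and records every segment the query takes whole. `build_map` and `query_trace` are hypothetical names.

```python
def build_map(arr, lo, hi, seg):
    """Store the sum of arr[lo..hi] in seg[(lo, hi)] for every node; return it."""
    if lo == hi:
        seg[(lo, hi)] = arr[lo]
    else:
        mid = (lo + hi) // 2
        seg[(lo, hi)] = build_map(arr, lo, mid, seg) + build_map(arr, mid + 1, hi, seg)
    return seg[(lo, hi)]


def query_trace(seg, lo, hi, l, r, used):
    """Sum of arr[l..r]; appends each segment taken whole to `used`."""
    if r < lo or hi < l:
        return 0  # no overlap
    if l <= lo and hi <= r:
        used.append(((lo, hi), seg[(lo, hi)]))  # full overlap: take it whole
        return seg[(lo, hi)]
    mid = (lo + hi) // 2  # partial overlap: split
    return (query_trace(seg, lo, mid, l, r, used) +
            query_trace(seg, mid + 1, hi, l, r, used))


arr = [1, 3, 5, 7, 9, 11, 13, 15]
seg = {}
build_map(arr, 0, len(arr) - 1, seg)
used = []
print(query_trace(seg, 0, len(arr) - 1, 2, 6, used))  # 45
print(used)  # [((2, 3), 12), ((4, 5), 20), ((6, 6), 13)]
```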


Segment Tree: Build & Range Sum Query

```python
class SegTree:
    def __init__(self, arr):
        n = len(arr)
        self.n = n
        self.tree = [0] * (4 * n)  # 0-indexed: children of node i are 2i+1 and 2i+2
        self._build(arr, 0, 0, n - 1)

    def _build(self, arr, node, lo, hi):
        if lo == hi:
            self.tree[node] = arr[lo]  # leaf
            return
        mid = (lo + hi) // 2
        self._build(arr, 2*node+1, lo, mid)
        self._build(arr, 2*node+2, mid+1, hi)
        self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]

    def query(self, l, r, node=0, lo=None, hi=None):
        if lo is None:
            lo, hi = 0, self.n - 1
        if r < lo or hi < l:
            return 0  # no overlap
        if l <= lo and hi <= r:
            return self.tree[node]  # full overlap
        mid = (lo + hi) // 2  # partial overlap: combine both halves
        return (self.query(l, r, 2*node+1, lo, mid) +
                self.query(l, r, 2*node+2, mid+1, hi))


arr = [1, 3, 5, 7, 9, 11]
st = SegTree(arr)
print(st.query(1, 3))  # 3+5+7 = 15
print(st.query(0, 5))  # total = 36
```
Output
15
36
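Note: Segment trees are not limited to sums. The same structure works for range minimum, range maximum, range GCD, range XOR, and more (any associative operation). Change what each node stores, the function that combines two children, and the identity value returned for a segment with no overlap (0 for sum, \(+\infty\) for minimum).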
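A minimal sketch of that idea, assuming the combine function is associative and `identity` satisfies combine(identity, x) == x. The class name `GenericSegTree` and its parameters are illustrative; apart from taking the combine function and identity as arguments, it mirrors the recursive sum tree above.

```python
import math


class GenericSegTree:
    """Recursive segment tree parameterized by an associative combine and its identity."""

    def __init__(self, arr, combine, identity):
        self.n = len(arr)
        self.combine = combine
        self.identity = identity  # returned for segments with no overlap
        self.tree = [identity] * (4 * self.n)
        self._build(arr, 0, 0, self.n - 1)

    def _build(self, arr, node, lo, hi):
        if lo == hi:
            self.tree[node] = arr[lo]
            return
        mid = (lo + hi) // 2
        self._build(arr, 2 * node + 1, lo, mid)
        self._build(arr, 2 * node + 2, mid + 1, hi)
        self.tree[node] = self.combine(self.tree[2 * node + 1], self.tree[2 * node + 2])

    def query(self, l, r, node=0, lo=0, hi=None):
        if hi is None:
            hi = self.n - 1
        if r < lo or hi < l:
            return self.identity  # no overlap
        if l <= lo and hi <= r:
            return self.tree[node]  # full overlap
        mid = (lo + hi) // 2
        return self.combine(self.query(l, r, 2 * node + 1, lo, mid),
                            self.query(l, r, 2 * node + 2, mid + 1, hi))


arr = [5, 2, 8, 6, 3, 7]
mins = GenericSegTree(arr, min, math.inf)
gcds = GenericSegTree([12, 18, 24, 9], math.gcd, 0)
xors = GenericSegTree(arr, lambda a, b: a ^ b, 0)
print(mins.query(2, 5))  # 3
print(gcds.query(0, 2))  # 6
print(xors.query(0, 1))  # 7
```

Switching between `min` with `math.inf`, `math.gcd` with 0, or XOR with 0 changes the operation without touching the tree logic.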

Building the tree

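Build bottom-up: place the array values at the leaves, then set each internal node to combine(left child, right child). For a sum segment tree, each node is the sum of its two children. The tree has exactly \(2n - 1\) nodes (n leaves and n − 1 internal nodes) and each is computed once in \(O(1)\), so the build takes \(O(n)\) time.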

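```python
arr = [1, 3, 5, 7, 9, 11, 13, 15]
tree = [0] * (4 * len(arr))  # 1-indexed: root at 1, children of node i at 2i and 2i+1

def build(node, start, end):
    if start == end:
        tree[node] = arr[start]  # leaf
    else:
        mid = (start + end) // 2
        build(2*node, start, mid)  # left child
        build(2*node+1, mid+1, end)  # right child
        tree[node] = tree[2*node] + tree[2*node+1]  # combine children

build(1, 0, len(arr) - 1)  # afterwards tree[1] == 64
```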

Querying a range

To query [l, r], start at the root. If the current segment lies entirely inside [l, r], return its stored value. If it lies entirely outside, return the identity element (0 for sum). Otherwise, recurse into both children and combine the results.

```python
def query(node, start, end, l, r):
    # Uses the 1-indexed tree filled by build(); call as query(1, 0, n - 1, l, r)
    if r < start or end < l:
        return 0  # no overlap: identity element for sum
    if l <= start and end <= r:
        return tree[node]  # full overlap
    mid = (start + end) // 2  # partial overlap: split and combine
    return (query(2*node, start, mid, l, r) +
            query(2*node+1, mid+1, end, l, r))
```

Point update

To set arr[i] = val, overwrite the leaf for index i, then walk back up to the root, recomputing each ancestor from its two children. Only the \(O(\log n)\) nodes on that single leaf-to-root path change.
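The walk up to the root is easiest to see in the iterative, bottom-up variant of the segment tree. It uses a different layout from the recursive code in this article: an array of 2n slots with arr[i]'s leaf at position n + i and node i's parent at i // 2. This is a sketch under those assumptions (`BottomUpSegTree` is a hypothetical name); the query loop shown is written for sums and works for any n, not just powers of two.

```python
class BottomUpSegTree:
    """Iterative sum segment tree: leaves at tree[n..2n-1], node i's parent at i // 2."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * self.n + list(arr)  # internal nodes first, then leaves
        for i in range(self.n - 1, 0, -1):    # fill parents after their children
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, i, val):
        i += self.n                 # position of arr[i]'s leaf
        self.tree[i] = val
        while i > 1:                # walk up, recomputing each ancestor
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def query(self, l, r):
        """Sum of arr[l..r] inclusive."""
        total = 0
        l += self.n
        r += self.n + 1             # work on the half-open range [l, r)
        while l < r:
            if l & 1:               # l is a right child: take it, move past it
                total += self.tree[l]
                l += 1
            if r & 1:               # r - 1 is a left child: take it
                r -= 1
                total += self.tree[r]
            l //= 2
            r //= 2
        return total


st = BottomUpSegTree([1, 3, 5, 7, 9, 11])
print(st.query(1, 3))  # 15
st.update(2, 10)       # arr[2]: 5 -> 10
print(st.query(1, 3))  # 20
```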

Lazy propagation

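For range updates (for example, add 5 to every element from index 3 to 8), updating each element separately touches every leaf in the range, which is \(O(n)\) in the worst case. Lazy propagation defers that work: when a node's segment lies entirely inside the update range, apply the change to that node's aggregate, record it as a pending ("lazy") update, and stop there. The pending update is pushed down to the children only when a later query or update needs to descend below that node. This brings range updates down to \(O(\log n)\) as well.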
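A minimal lazy-propagation sketch, assuming the update is "add a value to every element in [l, r]" and the query is a range sum. `LazySegTree`, `range_add`, `_apply` and `_push` are illustrative names. Each node keeps its segment's sum plus a pending add, and `_push` hands the pending add to the two children just before the recursion descends below that node.

```python
class LazySegTree:
    """Range add + range sum with lazy propagation (recursive, 0-indexed nodes)."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)  # sum of each node's segment
        self.lazy = [0] * (4 * self.n)  # pending add not yet pushed to children
        self._build(arr, 0, 0, self.n - 1)

    def _build(self, arr, node, lo, hi):
        if lo == hi:
            self.tree[node] = arr[lo]
            return
        mid = (lo + hi) // 2
        self._build(arr, 2 * node + 1, lo, mid)
        self._build(arr, 2 * node + 2, mid + 1, hi)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def _apply(self, node, lo, hi, val):
        # Add val to every element of [lo, hi] without visiting the children
        self.tree[node] += val * (hi - lo + 1)
        self.lazy[node] += val

    def _push(self, node, lo, hi):
        # Hand this node's pending add down to its two children
        if self.lazy[node]:
            mid = (lo + hi) // 2
            self._apply(2 * node + 1, lo, mid, self.lazy[node])
            self._apply(2 * node + 2, mid + 1, hi, self.lazy[node])
            self.lazy[node] = 0

    def range_add(self, l, r, val, node=0, lo=0, hi=None):
        if hi is None:
            hi = self.n - 1
        if r < lo or hi < l:
            return  # no overlap
        if l <= lo and hi <= r:
            self._apply(node, lo, hi, val)  # full overlap: defer the children
            return
        self._push(node, lo, hi)  # partial overlap: children must be current
        mid = (lo + hi) // 2
        self.range_add(l, r, val, 2 * node + 1, lo, mid)
        self.range_add(l, r, val, 2 * node + 2, mid + 1, hi)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def query(self, l, r, node=0, lo=0, hi=None):
        if hi is None:
            hi = self.n - 1
        if r < lo or hi < l:
            return 0
        if l <= lo and hi <= r:
            return self.tree[node]
        self._push(node, lo, hi)
        mid = (lo + hi) // 2
        return (self.query(l, r, 2 * node + 1, lo, mid) +
                self.query(l, r, 2 * node + 2, mid + 1, hi))


st = LazySegTree([1, 3, 5, 7, 9, 11])
print(st.query(1, 3))  # 15
st.range_add(2, 4, 5)  # array becomes [1, 3, 10, 12, 14, 11]
print(st.query(1, 3))  # 3 + 10 + 12 = 25
print(st.query(0, 5))  # 51
```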

Array vs. pointer implementation

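Segment trees are commonly stored in a flat array of size 4n rather than as linked nodes. With 1-based indexing (root at 1, as in the build and query functions above), node i has children at 2i and 2i + 1 and its parent at i // 2. The SegTree classes use 0-based indexing instead: children at 2i + 1 and 2i + 2, parent at (i − 1) // 2. Either way there is no pointer overhead and the nodes sit contiguously in memory, which tends to be cache-friendly. The size 4n is a safe upper bound: when n is not a power of two, the recursive layout can use indices beyond 2n.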
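To check the numbers behind that layout, the sketch below walks the recursive 1-indexed layout (the one used by the build and query functions above) and records every index it touches. `layout_stats` is a hypothetical helper. For each tested n the tree has exactly \(2n - 1\) nodes, and the largest index stays below 4n even though it can exceed 2n.

```python
def layout_stats(n):
    """Walk the 1-indexed recursive layout for n elements.

    Returns (number of nodes, largest index used).
    """
    used = []

    def visit(node, lo, hi):
        used.append(node)
        if lo < hi:
            mid = (lo + hi) // 2
            visit(2 * node, lo, mid)          # left child at 2i
            visit(2 * node + 1, mid + 1, hi)  # right child at 2i + 1

    visit(1, 0, n - 1)
    return len(used), max(used)


for n in range(1, 513):
    count, top = layout_stats(n)
    assert count == 2 * n - 1  # n leaves + (n - 1) internal nodes
    assert top < 4 * n         # an array of size 4n is always large enough

print(layout_stats(6))  # (11, 13): index 13 does not fit in an array of size 2n = 12
```

For the 0-indexed layout (children at 2i + 1 and 2i + 2) every index is exactly one smaller, so the same 4n bound holds.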

Complexity

| Operation | Time | Why |
|---|---|---|
| Build | \(O(n)\) | Computes each of the \(2n - 1\) nodes once, bottom-up |
| Range query | \(O(\log n)\) | At most 4 nodes visited per level, about \(4 \log n\) in total |
| Point update | \(O(\log n)\) | Updates the leaf and all its ancestors; path length = tree height |
| Range update (lazy) | \(O(\log n)\) | Requires lazy propagation; without it, a range update is \(O(n)\) |
| Space | \(O(n)\) | Tree array of size 4n |

Segment Tree: Point Update

```python
class SegTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self._build(arr, 0, 0, self.n - 1)

    def _build(self, arr, node, lo, hi):
        if lo == hi:
            self.tree[node] = arr[lo]
            return
        mid = (lo + hi) // 2
        self._build(arr, 2*node+1, lo, mid)
        self._build(arr, 2*node+2, mid+1, hi)
        self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]

    def update(self, idx, val, node=0, lo=None, hi=None):
        if lo is None:
            lo, hi = 0, self.n - 1
        if lo == hi:
            self.tree[node] = val  # reached the leaf for arr[idx]
            return
        mid = (lo + hi) // 2
        if idx <= mid:  # descend into the half that contains idx
            self.update(idx, val, 2*node+1, lo, mid)
        else:
            self.update(idx, val, 2*node+2, mid+1, hi)
        self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]  # refresh on the way back up

    def query(self, l, r, node=0, lo=None, hi=None):
        if lo is None:
            lo, hi = 0, self.n - 1
        if r < lo or hi < l:
            return 0  # no overlap
        if l <= lo and hi <= r:
            return self.tree[node]  # full overlap
        mid = (lo + hi) // 2
        return (self.query(l, r, 2*node+1, lo, mid) +
                self.query(l, r, 2*node+2, mid+1, hi))


st = SegTree([1, 3, 5, 7, 9, 11])
print(st.query(1, 3))  # 15
st.update(2, 10)  # change arr[2] from 5 to 10
print(st.query(1, 3))  # 3+10+7 = 20
```
Output
15
20
Challenge

Quick check

You have an array of 1,000,000 elements and will perform 1,000,000 range sum queries. What is the total time with a segment tree vs. naive approach?
