Skip to content

Commit 0811cd0

Browse files
[ENHANCEMENT] Add Wavelet Tree Data Structure (#7414)
* Implement Wavelet Tree with rank and kthSmallest methods * Implement Wavelet Tree with rank and kthSmallest methods * Fix checkstyle multiple variable declarations violation --------- Co-authored-by: Deniz Altunkapan <deniz.altunkapan@outlook.com>
1 parent e7f8979 commit 0811cd0

2 files changed

Lines changed: 352 additions & 0 deletions

File tree

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
package com.thealgorithms.datastructures.trees;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
6+
/**
 * A Wavelet Tree over a sequence of {@code int} values.
 *
 * <p>The tree recursively halves the closed value range {@code [low, high]}. Every internal node
 * stores, for each prefix of the subsequence routed to it, how many of those elements fall into the
 * lower half of its range. With {@code d = ceil(log2(max - min + 1))} levels this gives:
 * <ul>
 *   <li>{@link #rank(int, int)} and {@link #kthSmallest(int, int, int)} in {@code O(d)},</li>
 *   <li>{@link #select(int, int)} in {@code O(d * log n)} (one binary search per level),</li>
 *   <li>construction in {@code O(n * d)}.</li>
 * </ul>
 * Instances are immutable after construction; the input array is not retained.
 */
public class WaveletTree {

    /**
     * A tree node covering the closed value range {@code [low, high]}.
     * Declared static because it never needs the enclosing {@code WaveletTree} instance.
     */
    private static final class Node {
        final int low;
        final int high;
        /** Largest value routed to the left child; values above it go right. */
        final int mid;
        /** Child for {@code [low, mid]}, or {@code null} if no element falls there. */
        final Node left;
        /** Child for {@code [mid + 1, high]}, or {@code null} if no element falls there. */
        final Node right;
        /**
         * {@code leftCount[p]} is the number of the first {@code p} elements of this node's
         * subsequence that are {@code <= mid}. {@code null} on leaves.
         */
        final int[] leftCount;

        /**
         * Recursively builds the subtree for {@code arr}.
         *
         * @param arr  the subsequence routed to this node; every value must lie in {@code [low, high]}
         * @param low  the minimum possible value in this node
         * @param high the maximum possible value in this node
         */
        Node(int[] arr, int low, int high) {
            final int split = midpoint(low, high);
            this.low = low;
            this.high = high;
            this.mid = split;

            if (low == high) {
                // Leaf: all elements equal 'low', nothing left to partition.
                this.left = null;
                this.right = null;
                this.leftCount = null;
                return;
            }

            final int[] prefix = new int[arr.length + 1];
            for (int i = 0; i < arr.length; i++) {
                prefix[i + 1] = prefix[i] + (arr[i] <= split ? 1 : 0);
            }

            // Stable partition into primitive arrays (sizes are known from the prefix counts).
            final int leftSize = prefix[arr.length];
            final int[] lower = new int[leftSize];
            final int[] upper = new int[arr.length - leftSize];
            int li = 0;
            int ui = 0;
            for (int value : arr) {
                if (value <= split) {
                    lower[li++] = value;
                } else {
                    upper[ui++] = value;
                }
            }

            this.leftCount = prefix;
            this.left = lower.length > 0 ? new Node(lower, low, split) : null;
            // split < high whenever low < high, so split + 1 cannot overflow.
            this.right = upper.length > 0 ? new Node(upper, split + 1, high) : null;
        }
    }

    private final Node root;
    private final int n;

    /**
     * Constructs a Wavelet Tree from the given array.
     * The min and max values are determined from the array itself.
     *
     * @param arr the input array; {@code null} or empty yields an empty tree
     */
    public WaveletTree(int[] arr) {
        if (arr == null || arr.length == 0) {
            this.n = 0;
            this.root = null;
            return;
        }
        int min = arr[0];
        int max = arr[0];
        for (int value : arr) {
            min = Math.min(min, value);
            max = Math.max(max, value);
        }
        this.n = arr.length;
        this.root = new Node(arr, min, max);
    }

    /**
     * Constructs a Wavelet Tree from the given array with specific min and max values.
     *
     * @param arr      the input array; {@code null} or empty yields an empty tree
     * @param minValue the minimum possible value
     * @param maxValue the maximum possible value
     * @throws IllegalArgumentException if {@code arr} is non-empty and either
     *         {@code minValue > maxValue} or some element lies outside {@code [minValue, maxValue]}
     */
    public WaveletTree(int[] arr, int minValue, int maxValue) {
        if (arr == null || arr.length == 0) {
            this.n = 0;
            this.root = null;
            return;
        }
        if (minValue > maxValue) {
            throw new IllegalArgumentException("minValue (" + minValue + ") must not exceed maxValue (" + maxValue + ")");
        }
        // An out-of-range element would be routed to a boundary leaf and miscounted as that value.
        for (int value : arr) {
            if (value < minValue || value > maxValue) {
                throw new IllegalArgumentException("Element " + value + " is outside [" + minValue + ", " + maxValue + "]");
            }
        }
        this.n = arr.length;
        this.root = new Node(arr, minValue, maxValue);
    }

    /**
     * Floor of the average of {@code low} and {@code high}, computed in {@code long} so that ranges
     * wider than {@code Integer.MAX_VALUE} (e.g. MIN_VALUE..MAX_VALUE) do not overflow.
     */
    private static int midpoint(int low, int high) {
        return (int) (((long) low + (long) high) >> 1);
    }

    /**
     * How many times does the number x appear in the array from index 0 to i (inclusive)?
     *
     * @param x the number to search for
     * @param i the end index (0-based, inclusive); values past the end are capped at the last index
     * @return the number of occurrences of x in arr[0...i], or 0 if {@code i < 0} or x cannot occur
     */
    public int rank(int x, int i) {
        if (root == null || i < 0 || x < root.low || x > root.high) {
            return 0;
        }
        // 'count' is the length of the prefix still under consideration at the current node.
        int count = Math.min(i, n - 1) + 1;
        Node node = root;
        while (node != null && count > 0) {
            if (node.low == node.high) {
                return count;
            }
            final int wentLeft = node.leftCount[count];
            if (x <= node.mid) {
                node = node.left;
                count = wentLeft;
            } else {
                node = node.right;
                count -= wentLeft;
            }
        }
        return 0;
    }

    /**
     * What is the 0-based index of the k-th occurrence of the number x in the array?
     *
     * @param x the number to search for
     * @param k the occurrence count (1-based)
     * @return the 0-based index in the original array, or -1 if x occurs less than k times
     */
    public int select(int x, int k) {
        if (root == null || k <= 0 || x < root.low || x > root.high) {
            return -1;
        }
        if (rank(x, n - 1) < k) {
            return -1;
        }
        // From here on x is known to occur at least k times, so every node on its path exists.
        return select(root, x, k);
    }

    /** Resolves the position inside the child first, then maps it back into {@code node}. */
    private static int select(Node node, int x, int k) {
        if (node.low == node.high) {
            return k - 1; // 0-based index within the (implicit) all-equal sequence at the leaf
        }
        final boolean goLeft = x <= node.mid;
        final int posInChild = select(goLeft ? node.left : node.right, x, k);
        return positionOfKth(node.leftCount, posInChild + 1, goLeft);
    }

    /**
     * Finds the 0-based position of the k-th element that was routed to the chosen side.
     *
     * @param prefix    the node's left-count prefix sums
     * @param k         1-based ordinal among the elements routed to that side
     * @param countLeft {@code true} to count left-routed elements, {@code false} for right-routed
     * @return the 0-based position in the node's subsequence, or -1 if there is no such element
     */
    private static int positionOfKth(int[] prefix, int k, boolean countLeft) {
        int lo = 1;
        int hi = prefix.length - 1;
        int answer = -1;
        while (lo <= hi) {
            final int probe = lo + (hi - lo) / 2;
            final int seen = countLeft ? prefix[probe] : probe - prefix[probe];
            if (seen >= k) {
                answer = probe;
                hi = probe - 1;
            } else {
                lo = probe + 1;
            }
        }
        return answer == -1 ? -1 : answer - 1; // Convert prefix length to 0-based index
    }

    /**
     * If you sort the subarray from index left to right, what would be the k-th smallest element?
     * This query is also commonly known as the quantile query.
     *
     * <p>Note that -1 doubles as the error value, so callers storing -1 in the array must validate
     * their arguments themselves to tell the two cases apart.
     *
     * @param left  the start index of the subarray (0-based, inclusive)
     * @param right the end index of the subarray (0-based, inclusive)
     * @param k     the rank of the smallest element (1-based, e.g., k=1 is the minimum)
     * @return the k-th smallest element in the subarray, or -1 if the parameters are invalid
     *         (empty tree, {@code left < 0}, {@code right >= size}, {@code left > right},
     *         or k outside {@code [1, right - left + 1]})
     */
    public int kthSmallest(int left, int right, int k) {
        if (root == null || left < 0 || right >= n || left > right || k < 1 || k > right - left + 1) {
            return -1;
        }
        Node node = root;
        int from = left;
        int to = right;
        int remaining = k;
        while (node.low != node.high) {
            final int leftBefore = node.leftCount[from]; // left-routed elements before 'from'
            final int leftThrough = node.leftCount[to + 1]; // left-routed elements up to 'to'
            final int leftInRange = leftThrough - leftBefore;
            if (remaining <= leftInRange) {
                node = node.left;
                from = leftBefore;
                to = leftThrough - 1;
            } else {
                node = node.right;
                from -= leftBefore;
                to -= leftThrough;
                remaining -= leftInRange;
            }
        }
        return node.low;
    }
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package com.thealgorithms.datastructures.trees;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import org.junit.jupiter.api.Test;
6+
7+
public class WaveletTreeTest {
8+
9+
@Test
10+
public void testRank() {
11+
int[] arr = {5, 1, 2, 5, 1};
12+
WaveletTree wt = new WaveletTree(arr);
13+
14+
// x = 1
15+
assertEquals(1, wt.rank(1, 1)); // In [5, 1], '1' appears 1 time
16+
assertEquals(2, wt.rank(1, 4)); // In [5, 1, 2, 5, 1], '1' appears 2 times
17+
assertEquals(0, wt.rank(1, 0)); // In [5], '1' appears 0 times
18+
19+
// x = 5
20+
assertEquals(1, wt.rank(5, 0)); // In [5], '5' appears 1 time
21+
assertEquals(1, wt.rank(5, 2)); // In [5, 1, 2], '5' appears 1 time
22+
assertEquals(2, wt.rank(5, 4)); // In [5, 1, 2, 5, 1], '5' appears 2 times
23+
24+
// Out of bounds / invalid value
25+
assertEquals(0, wt.rank(10, 4)); // '10' is not in the array
26+
assertEquals(0, wt.rank(5, -1)); // Invalid end index
27+
}
28+
29+
@Test
30+
public void testSelect() {
31+
int[] arr = {5, 1, 2, 5, 1};
32+
WaveletTree wt = new WaveletTree(arr);
33+
34+
assertEquals(1, wt.select(1, 1)); // 1st '1' is at index 1
35+
assertEquals(4, wt.select(1, 2)); // 2nd '1' is at index 4
36+
37+
assertEquals(0, wt.select(5, 1)); // 1st '5' is at index 0
38+
assertEquals(3, wt.select(5, 2)); // 2nd '5' is at index 3
39+
40+
assertEquals(2, wt.select(2, 1)); // 1st '2' is at index 2
41+
42+
assertEquals(-1, wt.select(5, 3)); // 3rd '5' doesn't exist
43+
assertEquals(-1, wt.select(10, 1)); // '10' doesn't exist
44+
assertEquals(-1, wt.select(5, 0)); // invalid k
45+
}
46+
47+
@Test
48+
public void testKthSmallest() {
49+
int[] arr = {5, 1, 2, 5, 1};
50+
WaveletTree wt = new WaveletTree(arr);
51+
52+
// Array: [5, 1, 2, 5, 1] -> Sorted: [1, 1, 2, 5, 5]
53+
assertEquals(1, wt.kthSmallest(0, 4, 1)); // 1st smallest in [5, 1, 2, 5, 1] is 1
54+
assertEquals(1, wt.kthSmallest(0, 4, 2)); // 2nd smallest in [5, 1, 2, 5, 1] is 1
55+
assertEquals(2, wt.kthSmallest(0, 4, 3)); // 3rd smallest in [5, 1, 2, 5, 1] is 2
56+
assertEquals(5, wt.kthSmallest(0, 4, 4)); // 4th smallest in [5, 1, 2, 5, 1] is 5
57+
assertEquals(5, wt.kthSmallest(0, 4, 5)); // 5th smallest in [5, 1, 2, 5, 1] is 5
58+
59+
// Subarray: arr[1..3] = [1, 2, 5] -> Sorted: [1, 2, 5]
60+
assertEquals(1, wt.kthSmallest(1, 3, 1)); // 1st smallest in [1, 2, 5] is 1
61+
assertEquals(2, wt.kthSmallest(1, 3, 2)); // 2nd smallest in [1, 2, 5] is 2
62+
assertEquals(5, wt.kthSmallest(1, 3, 3)); // 3rd smallest in [1, 2, 5] is 5
63+
64+
// Invalid ranges / arguments
65+
assertEquals(-1, wt.kthSmallest(4, 2, 1)); // Invalid range (left > right)
66+
assertEquals(-1, wt.kthSmallest(0, 4, 10)); // k > range length
67+
assertEquals(-1, wt.kthSmallest(0, 4, 0)); // k < 1
68+
}
69+
70+
@Test
71+
public void testEmptyAndSingleElementArray() {
72+
WaveletTree wtEmpty = new WaveletTree(new int[] {});
73+
assertEquals(0, wtEmpty.rank(1, 0));
74+
assertEquals(-1, wtEmpty.select(1, 1));
75+
assertEquals(-1, wtEmpty.kthSmallest(0, 0, 1));
76+
77+
WaveletTree wtSingle = new WaveletTree(new int[] {42});
78+
assertEquals(1, wtSingle.rank(42, 0));
79+
assertEquals(0, wtSingle.rank(42, -1));
80+
assertEquals(0, wtSingle.select(42, 1));
81+
assertEquals(-1, wtSingle.select(42, 2));
82+
assertEquals(42, wtSingle.kthSmallest(0, 0, 1));
83+
}
84+
85+
@Test
86+
public void testNullArrayAndCustomBounds() {
87+
WaveletTree wtNull = new WaveletTree(null);
88+
assertEquals(0, wtNull.rank(1, 0));
89+
90+
WaveletTree wtNullCustom = new WaveletTree(null, 1, 5);
91+
assertEquals(-1, wtNullCustom.select(1, 1));
92+
93+
int[] arr = {5, 1, 2, 5, 1};
94+
WaveletTree wtCustom = new WaveletTree(arr, 1, 10);
95+
assertEquals(2, wtCustom.rank(5, 4));
96+
assertEquals(0, wtCustom.rank(4, 4)); // Query an element inside bounds but not in array
97+
assertEquals(0, wtCustom.rank(10, 4)); // Query upper bound
98+
}
99+
100+
@Test
101+
public void testNegativeValues() {
102+
int[] arr = {-5, 10, -2, 0, -5};
103+
WaveletTree wt = new WaveletTree(arr);
104+
105+
assertEquals(2, wt.rank(-5, 4));
106+
assertEquals(1, wt.rank(0, 3));
107+
108+
assertEquals(0, wt.select(-5, 1));
109+
assertEquals(4, wt.select(-5, 2));
110+
assertEquals(3, wt.select(0, 1));
111+
112+
// Sorted: [-5, -5, -2, 0, 10]
113+
assertEquals(-5, wt.kthSmallest(0, 4, 1));
114+
assertEquals(-2, wt.kthSmallest(0, 4, 3));
115+
assertEquals(10, wt.kthSmallest(0, 4, 5));
116+
}
117+
}

0 commit comments

Comments
 (0)