-
Notifications
You must be signed in to change notification settings - Fork 522
/
PersistentTree.java
52 lines (45 loc) · 1.6 KB
/
PersistentTree.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package structures;
// https://en.wikipedia.org/wiki/Persistent_data_structure
/**
 * Persistent segment tree over an inclusive index range {@code [left, right]} supporting point
 * assignment ({@link #set}) and range-sum queries ({@link #sum}).
 *
 * <p>{@link #set} never modifies an existing node: it allocates new nodes only along the path from
 * the root to the updated leaf and reuses every other subtree of the previous version. Roots
 * returned earlier therefore keep describing their own version of the data.
 *
 * <p>Every operation takes the index range the tree was built with; callers must pass the same
 * {@code left}/{@code right} pair that was given to {@link #build}.
 */
public class PersistentTree {

    /** Immutable tree node: a leaf stores one value, an inner node the sum of its two children. */
    public static class Node {
        /** Left child, or {@code null} for a leaf. */
        final Node left;
        /** Right child, or {@code null} for a leaf. */
        final Node right;
        /** Sum of the values under this node (plain {@code int} addition, so it wraps on overflow). */
        final int sum;

        /** Creates a leaf holding {@code value}. */
        Node(int value) {
            this.left = null;
            this.right = null;
            this.sum = value;
        }

        /** Creates an inner node whose sum is the sum of the two (non-null) children. */
        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
            this.sum = left.sum + right.sum;
        }
    }

    /**
     * Returns the midpoint of {@code [left, right]}, rounded toward negative infinity.
     * Equivalent to {@code (left + right) >> 1} but without overflowing on {@code left + right}.
     */
    private static int mid(int left, int right) {
        return left + ((right - left) >> 1);
    }

    /**
     * Builds the initial version of a tree covering {@code [left, right]} with every value 0.
     *
     * @param left  first index of the range (inclusive)
     * @param right last index of the range (inclusive)
     * @return root of the new tree
     * @throws IllegalArgumentException if {@code left > right} (the original recursed forever)
     */
    public static Node build(int left, int right) {
        if (left > right) {
            throw new IllegalArgumentException("empty range [" + left + ", " + right + "]");
        }
        if (left == right) {
            return new Node(0);
        }
        int mid = mid(left, right);
        return new Node(build(left, mid), build(mid + 1, right));
    }

    /**
     * Returns the sum of the values at indices {@code [from, to]} in the version rooted at
     * {@code root}. Indices outside the tree's range contribute nothing, and an empty query
     * ({@code from > to}) yields 0.
     *
     * @param from  first queried index (inclusive)
     * @param to    last queried index (inclusive)
     * @param root  root of the version to query
     * @param left  first index covered by {@code root} (as passed to {@link #build})
     * @param right last index covered by {@code root} (as passed to {@link #build})
     * @return the range sum, using {@code int} arithmetic
     */
    public static int sum(int from, int to, Node root, int left, int right) {
        if (from > right || left > to) {
            return 0; // node range is disjoint from the query
        }
        if (from <= left && right <= to) {
            return root.sum; // node range lies fully inside the query
        }
        int mid = mid(left, right);
        return sum(from, to, root.left, left, mid) + sum(from, to, root.right, mid + 1, right);
    }

    /**
     * Returns a new version in which index {@code pos} holds {@code value}. The version rooted at
     * {@code root} is left unchanged, and all subtrees off the updated path are shared with it.
     *
     * @param pos   index to assign; must lie in {@code [left, right]}
     * @param value new value for {@code pos}
     * @param root  root of the version to derive from
     * @param left  first index covered by {@code root} (as passed to {@link #build})
     * @param right last index covered by {@code root} (as passed to {@link #build})
     * @return root of the new version
     * @throws IndexOutOfBoundsException if {@code pos} is outside {@code [left, right]}
     *     (previously a boundary leaf was silently overwritten instead)
     */
    public static Node set(int pos, int value, Node root, int left, int right) {
        if (pos < left || pos > right) {
            throw new IndexOutOfBoundsException(
                    "pos " + pos + " outside [" + left + ", " + right + "]");
        }
        if (left == right) {
            return new Node(value);
        }
        int mid = mid(left, right);
        // Copy only the child on the path to pos; reuse the other child as-is.
        return pos <= mid
                ? new Node(set(pos, value, root.left, left, mid), root.right)
                : new Node(root.left, set(pos, value, root.right, mid + 1, right));
    }

    /** Usage example: builds an all-zero tree, derives a second version, and queries both. */
    public static void main(String[] args) {
        int n = 10;
        Node t1 = build(0, n - 1);
        Node t2 = set(0, 1, t1, 0, n - 1);
        System.out.println(0 == sum(0, n - 1, t1, 0, n - 1));
        System.out.println(1 == sum(0, n - 1, t2, 0, n - 1));
    }
}