1
+ # calculate the sum of the range
2
+ # update values in the array
3
+
4
+ class SegmentTree ():
5
+ def __init__ (self , sum : int , L : int , R : int ):
6
+ self .L = L
7
+ self .R = R
8
+ self .sum = sum
9
+ self .left = None
10
+ self .right = None
11
+ pass
12
+ @staticmethod
13
+ def build (arr : list [int ], L : int , R : int ) -> "SegmentTree" :
14
+ if L == R :
15
+ return SegmentTree (arr [L ], L , R )
16
+ M = (L + R ) // 2
17
+ root = SegmentTree (0 , L , R )
18
+ root .left = SegmentTree .build (arr , L , M )
19
+ root .right = SegmentTree .build (arr , M + 1 , R )
20
+ root .sum = root .left .sum + root .right .sum
21
+ return root
22
+
23
+ # runtime O(log(n))
24
+ def update (self , index : int , val : int ) -> None :
25
+ if self .L == self .R :
26
+ self .sum = val
27
+ return
28
+ M = ( self .L + self .R ) // 2
29
+ if M >= index :
30
+ self .left .update (index , val )
31
+ else :
32
+ self .right .update (index , val )
33
+ self .sum = self .left .sum + self .right .sum
34
+ # runtime O(log(n))
35
+ def queryRange (self , L : int , R : int ) -> int :
36
+ if self .L == L and self .R == R : return self .sum
37
+ M = (self .L + self .R ) // 2
38
+ if M >= R :
39
+ return self .left .queryRange (L , R )
40
+ elif M < L :
41
+ return self .right .queryRange (L , R )
42
+ else :
43
+ return (self .left .queryRange (L , M ) +
44
+ self .right .queryRange (M + 1 , R ))
45
+
46
+
47
+
48
+ t = SegmentTree .build ([1 ,2 ,3 ,4 ,5 ],0 ,4 )
49
+ print (f'root sum: { t .sum } ' )
50
+ print (f'left [{ t .left .L } , { t .left .R } ] sum: { t .left .sum } ' )
51
+ print (f'right [{ t .right .L } , { t .right .R } ] sum: { t .right .sum } ' )
52
+ print (f'query (0,3): { t .queryRange (0 ,3 )} ' ) # 10
53
+ print (f'query (1,3): { t .queryRange (1 ,3 )} ' ) # 9
54
+ print (f'query (2,3): { t .queryRange (2 ,3 )} ' ) # 7
55
+ print (f'query (3,4): { t .queryRange (3 ,4 )} ' ) # 9
56
+ t .update (3 ,1 )
57
+ print (f'query (3,4): { t .queryRange (3 ,4 )} ' ) # 6
0 commit comments