Segment tree with lazy propagation for multiples of 3


Abridged problem: you're given an array of n elements, all initially 0.

You will receive two types of queries. First type: 0 index1 index2, in which case you have to increase by one every element in the range index1 to index2 (inclusive).

Second type: 1 index1 index2, in which case you have to print how many elements between index1 and index2 (inclusive) are divisible by 3.

Of course, since n is very large (10^6), a good approach is to use a segment tree to store the intervals, with lazy propagation so that each update takes O(log n).

But I really don't know how to apply lazy propagation here, because every number can be in one of three states (3k, 3k+1 or 3k+2), not just two as in the coin-flipping problem.

If I put a flag on an interval that is contained in my query interval, I have to update it by looking at the original array and its values, and when I later update a child of that interval I have to do the same again, which is a waste of time.

Any better idea? I searched the net but found nothing.

EDIT: I followed your suggestions and wrote the code below (C++). It works for some base cases, but when I submit it I get just 10/100 points. What is wrong with it? (I know it's a bit long and there aren't many comments, but it's a simple segment tree with lazy propagation. If you don't understand something, please tell me!)

NOTE: st[p].zero counts the elements that are 0 mod 3 in the interval stored at index p, st[p].one the elements that are 1 mod 3, and st[p].two the elements that are 2 mod 3. When I update, I shift these counts by one position (0->1, 1->2, 2->0) and use lazy. When updating, I return a pair < int , pair< int, int > >, which is just a simple way to store a triple of numbers. This way I can return the change in the counts of numbers that are 0, 1 and 2 mod 3.

int sol;

struct mod{
    mod(){ zero=0; one=0;two=0;}
    int zero;
    int one;
    int two;  
};

class SegmentTree {         
public: int lazy[MAX_N];
  mod st[MAX_N];    
  int n;        
  int left (int p) { return p << 1; }     
  int right(int p) { return (p << 1) + 1; }

  void build(int p, int L, int R){
        if(L == R)
            st[p].zero=1;
        else{
            st[p].zero = R - L + 1;
            build(left(p), L, (L + R) / 2);
            build(right(p), ((L + R) / 2) + 1, R);
        }
        return;
  }

  void query(int p, int L, int R, int i, int j) {            
    if (L > R || i > R || j < L) return; 

    if(lazy[p]!=0){     // Check whether this node has a pending lazy update
        for(int k=0;k<lazy[p];k++){
            swap(st[p].zero,st[p].two);
            swap(st[p].one, st[p].two);
        }
        if(L != R){
            lazy[left(p)] = (lazy[left(p)] + lazy[p]) % 3;
            lazy[right(p)] = (lazy[right(p)] + lazy[p]) % 3;
        }
        lazy[p] = 0;
    } 


    if (L >= i && R <= j) { sol += st[p].zero;   return; }              


    query(left(p) , L              , (L+R) / 2, i, j);
    query(right(p), (L+R) / 2 + 1, R          , i, j);

    return; 
  }          

  pair < int, ii > update_tree(int p, int L, int R, int i, int j) {

    if (L > R || i > R || j < L){
      pair< int, pair< int, int > >  PP; PP.first=PP.second.first=PP.second.second=INF;
      return PP;
    }

    if(lazy[p]!=0){     // Check whether this node has a pending lazy update
        for(int k=0;k<lazy[p];k++){
            swap(st[p].zero,st[p].two);
            swap(st[p].one, st[p].two);
        }
        if(L != R){
            lazy[left(p)] = (lazy[left(p)] + lazy[p]) % 3;
            lazy[right(p)] = (lazy[right(p)] + lazy[p]) % 3;
        }
        lazy[p] = 0;
    } 

    if(L>=i && R<=j){
        swap(st[p].zero, st[p].two);
        swap(st[p].one, st[p].two);
        if(L != R){
            lazy[left(p)] = (lazy[left(p)] + 1) % 3;
            lazy[right(p)] = (lazy[right(p)] + 1) % 3;
        }
        pair< int, pair< int, int > > t; t.first = st[p].zero-st[p].one; t.second.first = st[p].one-st[p].two; t.second.second = st[p].two-st[p].zero;
        return t;
    }

    pair< int, pair< int, int > > s = update_tree(left(p), L, (L+R)/2, i, j); // Updating left child
    pair< int, pair< int, int > > s2 = update_tree(right(p), 1+(L+R)/2, R, i, j); // Updating right child
    pair< int, pair< int, int > > d2;
    d2.first = ( (s.first!=INF ? s.first : 0) + (s2.first!=INF ? s2.first : 0) ); // Calculating difference from the ones given by the children
    d2.second.first = ( (s.second.first!=INF ? s.second.first : 0) + (s2.second.first!=INF ? s2.second.first : 0) );
    d2.second.second = ( (s.second.second!=INF ? s.second.second : 0) + (s2.second.second!=INF ? s2.second.second : 0) );
    st[p].zero += d2.first; st[p].one += d2.second.first; st[p].two += d2.second.second; // Updating root 
    return d2;  // Return difference
  }

  public:
  SegmentTree(const vi &_A) {
    n = (int)_A.size();            
    build(1, 0, n - 1);                                  
  }

  void query(int i, int j) { return query(1, 0, n - 1, i, j); }   

  pair< int, pair< int, int > > update_tree(int i, int j) {
    return update_tree(1, 0, n - 1, i, j); }
};


int N,Q;

int main() {
    FILE * in; FILE * out;
    in = fopen("input.txt","r"); out = fopen("output.txt","w");

    fscanf(in, "%d %d" , &N, &Q);
    //cin>>N>>Q;
    int arr[N];
    vi A(arr,arr+N);

    SegmentTree *st = new SegmentTree(A);

    for(int i=0;i<Q;i++){
        int t,q,q2; 
        fscanf(in, "%d %d %d " , &t, &q, &q2);
        //cin>>t>>q>>q2;
        if(q > q2) swap(q, q2);
        if(t){
            sol=0;
            st->query(q,q2);
            fprintf(out, "%d\n", sol);           
            //cout<<sol<<endl;
        }
        else{
            pair<int, pair< int, int > > t = st->update_tree(q,q2);
        }
    }

    fclose(in); fclose(out);
    return 0;
}

There are 2 answers

tmyklebu On

It seems that you never have to care about the values of the elements, only their values modulo 3.

Keep a segment tree, using lazy updates as you suggest. Each node knows the number of things that are 0, 1, and 2 modulo 3 (memoization).

Each update hits log(n) nodes. When an update hits a node, you remember that you have to update the descendants (lazy update) and you cycle the memoized number of things in the subtree that are 0, 1, and 2 modulo 3.

Each query hits log(n) nodes; they're the same nodes an update of the same interval would hit. Whenever a query comes across a lazy update that hasn't been done, it pushes the update down to the descendants before recursing. Apart from that, it just adds up the number of elements that are 0 modulo 3 in each maximal subtree completely contained in the query interval.
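
A minimal C++ sketch of the scheme described above may help make it concrete. It rests on a few assumptions that are not in the original post: indices are 0-based and inclusive, and the names Mod3Tree, increment and divisibleBy3 are made up for illustration.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Hypothetical helper type for illustration; the names are not from the post.
struct Mod3Tree {
    int n;
    std::vector<std::array<int, 3>> cnt;  // cnt[p][r]: elements under p with value % 3 == r
    std::vector<int> lazy;                // increments (mod 3) still owed to p's children

    explicit Mod3Tree(int size) : n(size), cnt(4 * size), lazy(4 * size, 0) {
        assert(size > 0);
        build(1, 0, n - 1);
    }

    void build(int p, int lo, int hi) {
        cnt[p] = {hi - lo + 1, 0, 0};  // every element starts at 0
        if (lo == hi) return;
        int mid = (lo + hi) / 2;
        build(2 * p, lo, mid);
        build(2 * p + 1, mid + 1, hi);
    }

    // Add k to every element under p: the three counts rotate by k positions.
    void apply(int p, int k) {
        k %= 3;
        std::array<int, 3> old = cnt[p];
        for (int r = 0; r < 3; ++r) cnt[p][(r + k) % 3] = old[r];
        lazy[p] = (lazy[p] + k) % 3;
    }

    // Hand the pending increments to both children (only called on inner nodes).
    void push(int p) {
        if (lazy[p] == 0) return;
        apply(2 * p, lazy[p]);
        apply(2 * p + 1, lazy[p]);
        lazy[p] = 0;
    }

    void update(int p, int lo, int hi, int i, int j) {
        if (j < lo || hi < i) return;                      // disjoint
        if (i <= lo && hi <= j) { apply(p, 1); return; }   // fully covered
        push(p);
        int mid = (lo + hi) / 2;
        update(2 * p, lo, mid, i, j);
        update(2 * p + 1, mid + 1, hi, i, j);
        for (int r = 0; r < 3; ++r) cnt[p][r] = cnt[2 * p][r] + cnt[2 * p + 1][r];
    }

    int query(int p, int lo, int hi, int i, int j) {
        if (j < lo || hi < i) return 0;
        if (i <= lo && hi <= j) return cnt[p][0];
        push(p);
        int mid = (lo + hi) / 2;
        return query(2 * p, lo, mid, i, j) + query(2 * p + 1, mid + 1, hi, i, j);
    }

    void increment(int i, int j) { update(1, 0, n - 1, i, j); }
    int divisibleBy3(int i, int j) { return query(1, 0, n - 1, i, j); }
};
```

In this sketch a node's counts are always up to date for its own range, and lazy[p] only records what the children have not yet received, so a fully covered node can answer a query without touching its subtree.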

kraskevich On

You can store two values in each node:
1) int count[3] - how many elements in this node's segment have remainder 0, 1 and 2 modulo 3.
2) int shift - the pending shift value (initially zero).

The operations are performed in the following way(I use pseudo code):

add_one(node v)
    v.shift += 1
    v.shift %= 3

propagate(node v)
    v.left.shift += v.shift
    v.left.shift %= 3
    v.right.shift += v.shift
    v.right.shift %= 3
    v.shift = 0
    for i = 0..2:
        v.count[i] = get_count(v.left, i) + get_count(v.right, i)

get_count(node v, int remainder)
    return v.count[(remainder - v.shift + 3) % 3]

The number of elements divisible by 3 for a node v is get_count(v, 0). An update for a node is the add_one operation. Otherwise, it can be used as an ordinary segment tree (to answer range queries).

The entire tree update looks like this:

update(node v, int left, int right)
    if v is fully covered by [left; right]:
        add_one(v)
    else:
        propagate(v)
        if [left; right] intersects with the left child:
            update(v.left, left, right)
        if [left; right] intersects with the right child:
            update(v.right, left, right)
        for i = 0..2:
            v.count[i] = get_count(v.left, i) + get_count(v.right, i)

Getting the number of elements divisible by 3 is done in a similar manner.
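
The pseudocode above, together with the query it leaves out, could be sketched in C++ roughly as follows. This is only an illustration under assumptions that are not in the answer: 0-based inclusive indices, an array-backed tree with children at 2v and 2v+1, and the made-up names ShiftTree, increment and divisibleBy3. The sketch indexes count with (r - shift + 3) % 3, so that count[r] keeps the meaning "elements whose remainder was r before the pending shift".

```cpp
#include <array>
#include <cassert>
#include <vector>

// Hypothetical helper type for illustration; the names are not from the answer.
struct ShiftTree {
    struct Node {
        std::array<int, 3> count{};  // counts by remainder, before `shift` is applied
        int shift = 0;               // pending increment (mod 3) for this whole segment
    };
    int n;
    std::vector<Node> t;

    explicit ShiftTree(int size) : n(size), t(4 * size) {
        assert(size > 0);
        build(1, 0, n - 1);
    }

    void build(int v, int lo, int hi) {
        t[v].count[0] = hi - lo + 1;  // every element starts at 0
        if (lo == hi) return;
        int mid = (lo + hi) / 2;
        build(2 * v, lo, mid);
        build(2 * v + 1, mid + 1, hi);
    }

    // Elements of v's segment whose current remainder is r.
    int getCount(int v, int r) const {
        return t[v].count[(r - t[v].shift + 3) % 3];
    }

    // Recompute v's counts from its children; requires t[v].shift == 0.
    void pull(int v) {
        for (int r = 0; r < 3; ++r)
            t[v].count[r] = getCount(2 * v, r) + getCount(2 * v + 1, r);
    }

    // Move v's shift to its children (only called on inner nodes).
    void propagate(int v) {
        t[2 * v].shift = (t[2 * v].shift + t[v].shift) % 3;
        t[2 * v + 1].shift = (t[2 * v + 1].shift + t[v].shift) % 3;
        t[v].shift = 0;
        pull(v);
    }

    void update(int v, int lo, int hi, int i, int j) {
        if (j < lo || hi < i) return;
        if (i <= lo && hi <= j) {  // add_one
            t[v].shift = (t[v].shift + 1) % 3;
            return;
        }
        propagate(v);
        int mid = (lo + hi) / 2;
        update(2 * v, lo, mid, i, j);
        update(2 * v + 1, mid + 1, hi, i, j);
        pull(v);
    }

    int query(int v, int lo, int hi, int i, int j) {
        if (j < lo || hi < i) return 0;
        if (i <= lo && hi <= j) return getCount(v, 0);
        propagate(v);
        int mid = (lo + hi) / 2;
        return query(2 * v, lo, mid, i, j) + query(2 * v + 1, mid + 1, hi, i, j);
    }

    void increment(int i, int j) { update(1, 0, n - 1, i, j); }
    int divisibleBy3(int i, int j) { return query(1, 0, n - 1, i, j); }
};
```

The query follows the same descent as the update: a fully covered node answers with getCount(v, 0) without touching its children, and a partially covered node propagates its shift before recursing.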