How to Implement Trie Delete Function Without Overlapping Error

94 views Asked by At

I'm trying to implement a Trie in c++ in order to solve this problem https://codeforces.com/problemset/problem/706/D, and I've gotten everything down except for the delete function. For some reason, even though my code checks to make sure that we don't delete necessary elements, it still does so. I even followed the Digital Ocean trie explanation but that didn't help either. This results in WA for test case #8.

This is my current code, and if you scroll down a bit, you'll find the delete function and its helper methods.

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
 
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const ll INF = 100000000000;
const ll MOD = 1000000007;
const int MAX_N = 1000005;
 
using namespace __gnu_pbds;
template<typename T> using ordered_set = tree<T, null_type, less<T>, 
rb_tree_tag, tree_order_statistics_node_update>;
 
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

typedef struct trie_node trie_node;

char *decimal_binary(int val) {
    char *bin_rep = (char*)(calloc(32, sizeof(char)));
    for(int i = 31; i >= 0; --i) {
        if(val & (1 << i)) {
            bin_rep[i] = '1';
        } else {
            bin_rep[i] = '0';
        }
    }
    reverse(bin_rep, bin_rep+32);
    return bin_rep;
}

int binary_decimal(char *bs) {
    int sum = 0;
    for(int i = 31, j = 0; i >= 0; --i, ++j) {
        sum += ((int)(bs[i]-'0'))*pow(2, j);
    }
    return sum;
}
 
struct trie_node {
    trie_node *children[2];
    bool is_leaf = false;
};
 
trie_node *make_node() {
    trie_node *node = new trie_node;
    for(int i = 0; i < 2; ++i) {
        node->children[i] = NULL;
    }
    node->is_leaf = false;
    return node;
}
 
void unload_node(trie_node *node) {
    for(int i = 0; i < 2; ++i) {
        if(node->children[i] != NULL) {
            unload_node(node->children[i]);
        } else {
            continue;
        }
    }
    free(node);
}
 
trie_node *insert_node(trie_node *root, char *bs) {
    trie_node *temp = root;
    for(int i = 0; bs[i] != '\0'; ++i) {
        int idx = (int)(bs[i]-'0');
        if(temp->children[idx] == NULL) {
            temp->children[idx] = make_node();
        }
        temp = temp->children[idx];
    }
    temp->is_leaf = true;
    return root;
}

bool check_leaf(trie_node *root, char *bs) {
    trie_node *temp = root;
    for(int i = 0; bs[i]; ++i) {
        int idx = (int)bs[i]-'0';
        if(temp->children[idx] != NULL) {
            temp = temp->children[idx];
        }
    }
    return temp->is_leaf;
}

int earliest_branch(trie_node *root, char *bs) {
    trie_node *temp = root;
    int n = strlen(bs);
    if(n == 0) return 0;
    int last_idx = 0;
    for(int i = 0; i < n; ++i) {
        int idx = bs[i]-'0';
        if(temp->children[idx]) {
            for(int j = 0; j < 2; ++j) {
                if(j != idx && temp->children[j] != NULL) {
                    last_idx = i+1;
                    break;
                }
            }
            temp = temp->children[idx];
        }
    }
    return last_idx;
}

char *longest_prefix(trie_node *root, char *bs) {
    if(!bs || bs[0] == '\0') return NULL;
    
    int n = strlen(bs);
    char *lgt_prefix = (char*)(calloc(n+1, sizeof(char)));
    for(int i = 0; bs[i] != '\0'; ++i) {
        lgt_prefix[i] = bs[i];
    }
    lgt_prefix[n] = '\0';
    
    int branch_idx = earliest_branch(root, lgt_prefix)-1;
    if(branch_idx >= 0) {
        lgt_prefix[branch_idx] = '\0';
        lgt_prefix = (char*)(realloc(lgt_prefix, (branch_idx+1)*sizeof(char)));
    }
    
    return lgt_prefix;
}

trie_node *delete_node(trie_node *root, char *bs) {
    if(!root) return NULL;
    if(!bs || bs[0] == '\0') return root;
    if(!check_leaf(root, bs)) return root;
    
    trie_node *temp = root;
    char *lgt_prefix = longest_prefix(root, bs);
    if(lgt_prefix[0] == '\0') {
        free(lgt_prefix);
        return root;
    }
    int pos;
    for(pos = 0; lgt_prefix[pos] != '\0'; ++pos) {
        int idx = (int)lgt_prefix[pos]-'0';
        if(temp->children[idx] != NULL) {
            temp = temp->children[idx];
        } else {
            free(lgt_prefix);
            return root;
        }
    }
    int n = strlen(bs);
    for(; pos < n; ++pos) {
        int idx = (int)bs[pos]-'0';
        if(temp->children[idx]) {
            trie_node *extra = temp->children[idx];
            temp->children[idx] = NULL;
            unload_node(extra);
        }
    } 
    free(lgt_prefix);
    return root;
}
 
char *search_trie(trie_node *root, char *bs) {
    char *res = (char*)(calloc(32, sizeof(char)));
    for(int i = 0; i < 32; ++i) {
        res[i] = '0';
    }
    trie_node *temp = root;
    for(int i = 0; i < 32; ++i) {
        int idx = (((int)(bs[i]-'0'))+1)%2;
        if(temp->children[idx] != NULL) {
            res[i] = '1';
            temp = temp->children[idx];
        } else if(temp->children[(idx+1)%2] != NULL){
            temp = temp->children[(idx+1)%2];
        } else {
            break;
        }
    }
    return res;
}

int main() {
    cin.tie(0)->sync_with_stdio(0);
    int t;
    cin >> t;
    
    trie_node *root = make_node();
    map<int, int> cnt;
    char *tmpbs = decimal_binary(0);
    root = insert_node(root, tmpbs);
    while(t--) {
        char type; int val;
        cin >> type >> val;
        char *bs = decimal_binary(val);
        if(type == '+') {
            if(cnt[val] == 0) {
                root = insert_node(root, bs);
            }
            ++cnt[val];
        } else if(type == '-') {
            --cnt[val];
            if(cnt[val] == 0) {
                root = delete_node(root, bs);
            }
        } else {
            char *res = search_trie(root, bs);
            int ans = binary_decimal(res);
            cout << ans << "\n";
        }
    }
    unload_node(root);
    
    return 0;
}

Thanks!

UPD 1: Here is a test case where my code fails

14
? 1
+ 1
+ 7
? 2
+ 3
? 1
? 6
+ 4
+ 8
- 8
+ 6
+ 6
- 6
? 3

My output:

1
5
6
7
5

Correct output:

1
5
6
7
7
0

There are 0 answers