#include "oagSswAiBitVectorDict.h"

using namespace oagAi;

namespace oagSsw
{

// Builds an empty dictionary that reads simulation bits through simEngine.
// NOTE(review): simEngine is presumably stored by reference (see header); if
// so, the engine must outlive this dictionary — confirm.
//
// The trie starts as a single empty leaf, bitOrder is left empty and the bin
// limit is UINT_MAX. With an empty bitOrder every insertion lands in the root
// leaf; call initialize() to choose branching bits and a real bin limit.
AiBitVectorDict::AiBitVectorDict(AiSimEngine &simEngine)
: simEngine(simEngine)
{
    maxBinSize = UINT_MAX;
    trie       = new TrieLeaf;
}

void
AiBitVectorDict::initialize(const vector<oaUInt4>   &bitOrder,
                            oaUInt4                 maxBinSize)
{
    delete trie;
    trie = new TrieLeaf;
    this->bitOrder = bitOrder;
    this->maxBinSize = maxBinSize;
}

// Releases the whole trie. Branch nodes free their own children in
// ~TrieBranch, so deleting the root frees every node.
// NOTE(review): the root may be a TrieLeaf or a TrieBranch held as a
// TrieNode*. Deleting it this way is only well-defined if TrieNode declares a
// virtual destructor (header not visible here) — confirm.
AiBitVectorDict::~AiBitVectorDict()
{
    TrieNode *root = trie;
    delete root;
}

// Adds reference x to the dictionary. Inserting can turn a full root leaf
// into a branch, so the root pointer is refreshed from the helper's result.
void
AiBitVectorDict::insert(oagAi::Ref x)
{
    TrieNode *newRoot = insertRec(x, trie, 0);
    trie = newRoot;
}

// Recursively inserts x into the subtree rooted at `trie`, which sits at the
// given depth. Returns the node that must replace `trie` in its parent: the
// same node, or a new branch if a full leaf had to be split.
//
// A leaf accepts x while it holds fewer than maxBinSize references, and
// always accepts it once depth has reached bitOrder.size(). Otherwise the
// leaf is replaced by a branch on bit bitOrder[depth]. The leaf's old contents
// and x are then redistributed into two new child leaves, which may split again.
//
// NOTE(review): the split path allocates several nodes with plain `new`. If
// any allocation (or a nested insertion) throws, the nodes already allocated
// here are leaked.
AiBitVectorDict::TrieNode *
AiBitVectorDict::insertRec(oagAi::Ref   x,
                           TrieNode     *trie,
                           oaUInt4      depth)
{
    if (!trie->isLeaf) {
        // Interior node: route x by its value of bit bitOrder[depth] and
        // store back whatever node that child becomes.
        TrieBranch *node = static_cast<TrieBranch*>(trie);

        assert(depth < bitOrder.size());
        const oaUInt4 side = simEngine.getBit(x, bitOrder[depth]);
        TrieNode     *next = node->child[side];

        assert(next);
        node->child[side] = insertRec(x, next, depth + 1);
        return node;
    }

    TrieLeaf   *bucket     = static_cast<TrieLeaf*>(trie);
    const bool  atMaxDepth = depth >= bitOrder.size();

    if (atMaxDepth || bucket->bin.size() < maxBinSize) {
        // Room left, or no more bits to branch on: keep x in this leaf.
        bucket->bin.push_back(x);
        return bucket;
    }

    // The leaf is full and bits remain. Replace it with a branch on
    // bitOrder[depth] that has two empty leaves, then push the old contents
    // and x through that branch. Children that overflow split in turn.
    TrieLeaf   *zeroSide = new TrieLeaf;
    TrieLeaf   *oneSide  = new TrieLeaf;
    TrieBranch *parent   = new TrieBranch(zeroSide, oneSide);

    TrieLeaf::Bin::iterator it = bucket->bin.begin();
    while (it != bucket->bin.end()) {
        // Inserting into a branch always returns that same branch, so the
        // result does not need to be kept.
        insertRec(*it, parent, depth);
        ++it;
    }
    delete bucket;

    insertRec(x, parent, depth);
    return parent;
}

// Appends to `matches` the contents of every leaf reachable from the root.
// At each level, a bit set in `mask` follows only x's value of that bit;
// a bit clear in `mask` explores both children. Leaf bins are copied without
// further filtering, so callers presumably receive candidates rather than
// exact matches. `matches` is appended to, not cleared.
void
AiBitVectorDict::search(oagAi::Ref          x,
                        const BitVector     &mask,
                        list<oagAi::Ref>    &matches) const
{
    const oaUInt4 rootDepth = 0;
    searchRec(x, mask, trie, rootDepth, matches);
}

// Recursive worker for search(). Visits the subtree rooted at `trie`, which
// sits at the given depth, and appends the bin of every reachable leaf to
// `matches`. At a branch the tested bit is bitOrder[depth]. If that bit is set
// in `mask`, only the child matching x's value of it is visited; otherwise
// both children are visited, child[0] first.
void
AiBitVectorDict::searchRec(oagAi::Ref       x,
                           const BitVector  &mask,
                           TrieNode         *trie,
                           oaUInt4          depth,
                           list<oagAi::Ref> &matches) const
{
    if (trie->isLeaf) {
        // Leaf bins are copied wholesale; nothing is filtered per reference.
        const TrieLeaf *bucket = static_cast<const TrieLeaf*>(trie);
        matches.insert(matches.end(), bucket->bin.begin(), bucket->bin.end());
        return;
    }

    const TrieBranch *node        = static_cast<const TrieBranch*>(trie);
    const oaUInt4     bitIdx      = bitOrder[depth];
    const bool        constrained = mask.getBit(bitIdx);
    const oaUInt4     nextDepth   = depth + 1;

    if (!constrained) {
        // Don't-care bit: matches may live under either value.
        searchRec(x, mask, node->child[0], nextDepth, matches);
        searchRec(x, mask, node->child[1], nextDepth, matches);
        return;
    }

    // Constrained bit: only x's side of the branch can hold matches.
    const oaUInt4 side = simEngine.getBit(x, bitIdx);
    searchRec(x, mask, node->child[side], nextDepth, matches);
}

// Creates an interior trie node that takes ownership of both children; they
// are freed by ~TrieBranch. insertRec routes a reference to child[b], where b
// is its value of the branching bit. The `false` passed to TrieNode
// presumably marks the node as a non-leaf (the isLeaf flag tested by
// insertRec/searchRec) — confirm in the header.
AiBitVectorDict::TrieBranch::TrieBranch(TrieNode *child0,
                                        TrieNode *child1)
: TrieNode(false)
{
    TrieNode *const kids[2] = { child0, child1 };
    for (int i = 0; i < 2; ++i) {
        child[i] = kids[i];
    }
}

// Recursively frees both subtrees owned by this branch, child[0] first.
// NOTE(review): each child may be a TrieLeaf or a TrieBranch deleted through
// TrieNode*. This is only well-defined if TrieNode has a virtual destructor
// (header not visible here) — confirm.
AiBitVectorDict::TrieBranch::~TrieBranch()
{
    for (int side = 0; side < 2; ++side) {
        delete child[side];
    }
}

}
