#include <cstdlib>
#include <iostream>
#include <iomanip>
#include "oagSswAiGraphUtil.h"
#include "oagSswSatSweepingEngine.h"
#include "oagSswAiSatSolver.h"
#include "oagSswLogger.h"

using namespace std;
using namespace oa;
using namespace oagAi;

namespace oagSsw
{

// Report an unrecoverable internal inconsistency and terminate.  The message
// is printed before assert() so it is also visible in debug builds, and the
// process exits with a failure status so callers and scripts can detect it.
#define QUIT_ON_INTERNAL_ERROR { cerr << "ERROR: Internal error" << endl; assert(false); exit(EXIT_FAILURE); }

// Name under which this engine's logger is looked up (Logger::getInstance).
const string                LOGGER_NAME = "oagSsw";

// Multiplier applied to the per-check SAT backtrack budget by
// increaseResources().
const oa::oaDouble          SatSweepingEngine::BACKTRACKS_INCR_FACTOR = 2.0;

// Build an engine that sweeps every node of `graph`.
//
// Every node returned by graph.getAll(), plus the constant-zero node, becomes
// a sweep candidate.  Terminals without a driver and sequential nodes among
// them are registered as SAT inputs (addInputsFrom).  Initial levels are
// computed, and the simulation vector list is sized from the largest
// reference value seen.
SatSweepingEngine::SatSweepingEngine(oagAi::Graph &graph)
: graph(graph),
  levelizer(graph),
  simEngine(graph)
{
    list<Ref> everyNode;
    graph.getAll(everyNode);

    // Register candidates in list order while tracking the largest reference.
    oaUInt4 largestRef = 0;
    for (list<Ref>::const_iterator pos = everyNode.begin();
         pos != everyNode.end();
         ++pos) {
        nodes.insert(*pos);
        const oaUInt4 refValue = static_cast<oaUInt4>(*pos);
        if (refValue > largestRef) {
            largestRef = refValue;
        }
    }
    nodes.insert(graph.constantZero());
    addInputsFrom(everyNode);

    levelizer.initialLevels();

    simEngine.setSimVectorListLength(largestRef + 2);

    logger = Logger::getInstance(LOGGER_NAME);

    clearStats();
}

// Build an engine that sweeps only the given nodes (plus the constant-zero
// node).
//
// Phase is ignored: the non-inverted form of each entry is stored.  Only the
// driverless terminals and sequential nodes that appear in `nodes` are
// registered as SAT inputs (addInputsFrom(nodes)).  The simulation vector list
// is sized from the largest reference in the whole graph, not just the subset.
//
// NOTE(review): verifyEquivalence() and extendSimVectors() only assign
// registered inputs, so inputs in the fanin cones that are absent from `nodes`
// never receive counterexample values.  Confirm callers pass the inputs they
// need.
SatSweepingEngine::SatSweepingEngine(oagAi::Graph           &graph, 
                                     const list<oagAi::Ref> &nodes)
: graph(graph),
  levelizer(graph),
  simEngine(graph)
{
    for (list<Ref>::const_iterator it = nodes.begin();
         it != nodes.end();
         ++it) {
        this->nodes.insert(graph.getNonInverted(*it));
    }
    this->nodes.insert(graph.constantZero());
    addInputsFrom(nodes);

    levelizer.initialLevels();

    // Fanin cones of the selected nodes extend beyond the subset, so scan
    // every node of the graph to find the largest reference value.
    list<Ref> allNodes;
    graph.getAll(allNodes);
    oaUInt4 maxRef = 0;
    for (list<Ref>::const_iterator it = allNodes.begin();
         it != allNodes.end();
         ++it) {
        if ((oaUInt4)(*it) > maxRef)
            maxRef = (oaUInt4)(*it);
    }

    simEngine.setSimVectorListLength(maxRef + 2);

    logger = Logger::getInstance(LOGGER_NAME);

    clearStats();
}

// Reset every statistic (merge count, SAT outcome counters, runtimes) to zero.
void
SatSweepingEngine::clearStats()
{
    // Accumulated runtimes, in the units returned by oaTimer::getElapsed().
    stats.runtime.simulation      = 0.0;
    stats.runtime.satChecks       = 0.0;
    stats.runtime.computeClusters = 0.0;
    stats.runtime.sortClusters    = 0.0;

    // SAT outcome counters.
    stats.sat.numBacktracks    = 0;
    stats.sat.numUndecidable   = 0;
    stats.sat.numUnsatisfiable = 0;
    stats.sat.numSatisfiable   = 0;

    stats.numMerges = 0;
}

void
SatSweepingEngine::addInputsFrom(const list<Ref> &pool)
{
    for (list<Ref>::const_iterator it = pool.begin(); it != pool.end(); ++it) {
        Ref x = *it;
        switch (graph.getNodeType(x)) {
          case Node::TERMINAL:
            if (!graph.isNull(graph.getTerminalDriver(x))) {
                break;
            }
            // fall through for null-driven terminals
          case Node::SEQUENTIAL:
            inputs.insert(graph.getNonInverted(x));
            break;
            
          default:
            ; // do nothing
        }
    }
}

void
SatSweepingEngine::levelize()
{
    levels.clear();
    for (Set<Ref>::iterator it = nodes.begin(); it != nodes.end(); ++it) {
        Ref x = *it;
        oaUInt4 level = levelizer.levelOf(x);
        if (levels.size() < level + 1) {
            levels.resize(level + 1);
        }
        levels[level].insert(graph.getNonInverted(x));
    }
}

// True if the node referenced by x (in either phase) has been merged, i.e.
// has an entry in `reps`.
bool 
SatSweepingEngine::hasRep(oagAi::Ref x)
{
    const Ref key = graph.getNonInverted(x);
    const bool merged = reps.find(key) != reps.end();
    return merged;
}

// Return the representative the node referenced by x was merged into, or
// graph.getNull() if it was never merged.
//
// The lookup ignores x's own phase: it returns the representative of the
// non-inverted node.  Chains of merges are followed and compressed, so later
// lookups are direct.
oagAi::Ref
SatSweepingEngine::getRep(oagAi::Ref x)
{
    x = graph.getNonInverted(x);

    Map<oagAi::Ref, oagAi::Ref>::iterator it = reps.find(x);
    if (it == reps.end()) {
        return graph.getNull();
    }

    Ref rep = it->second;
    if (hasRep(rep)) {
        // rep was itself merged.  getRep() drops the phase of its argument,
        // so a complemented rep must complement the root it resolves to.
        // Otherwise the compressed entry below would flip x's phase.
        Ref root = getRep(rep);
        if (rep != graph.getNonInverted(rep)) {
            root = graph.notOf(root);
        }
        rep = root;
        reps[x] = rep;
    }
    return rep;
}

void
SatSweepingEngine::run()
{
    OAGSSW_LOGLN(logger, "Running SAT sweeping...");
    OAGSSW_LOGLN(logger, nodes.size() << " nodes");
    OAGSSW_LOGLN(logger, inputs.size() << " inputs");

    oaTimer timer;

    activeNodes.clear();
    activeNodes.insert(activeNodes.end(), nodes.begin(), nodes.end());

    //levelize();
    initializeSimVectors();
    
    oaUInt4 prevNumActive   = activeNodes.size();
    oaUInt4 prevNumClusters = 1;
    while (true) {
        {
            oaTimer timer;
            simEngine.simulateAll();
            stats.runtime.simulation += timer.getElapsed();
        }

        updateActive();
        computeClusters();
        
        if (clusters.empty()) {
            break;
        }
        if (prevNumActive == activeNodes.size()
            && prevNumClusters == clusters.size()) {
            OAGSSW_LOGLN(logger, "increase resource");
            increaseResources();
        }
        if (clusters.at(0).size() <= 1 || reachedResourceLimit()) {
            break;
        }
        
        prevNumActive = activeNodes.size();
        prevNumClusters = clusters.size();

        refineClusters();
    }

    OAGSSW_LOGLN(logger, stats.numMerges << " nodes were merged.");
    if (logger->isEnabled()) {
        oaUInt4 numAndsMerged = 0;
        for (Set<Ref>::iterator it = nodes.begin(); it != nodes.end(); it++) {
            if (hasRep(*it)) { 
                numAndsMerged++;
            }
        }
        OAGSSW_LOGLN(logger, numAndsMerged << " And nodes were merged.");
    }
    OAGSSW_LOGLN(logger, "Finished SAT sweeping.");
    OAGSSW_LOGLN(logger, "Elapsed time: " << setprecision(3) << timer.getElapsed() << "s");
    OAGSSW_LOGLN(logger, "SAT check time: " << setprecision(3) << stats.runtime.satChecks << "s");
    OAGSSW_LOGLN(logger, "Simulation time: " << setprecision(3) << stats.runtime.simulation << "s");
    OAGSSW_LOGLN(logger, "compute cluster time: " << setprecision(3) << stats.runtime.computeClusters << "s");
    OAGSSW_LOGLN(logger, "sort cluster time: " << setprecision(3) << stats.runtime.sortClusters << "s");
    OAGSSW_LOGLN(logger, "sat " << stats.sat.numSatisfiable << " unsat" << stats.sat.numUnsatisfiable << " undecidable " << stats.sat.numUndecidable);
}


void
SatSweepingEngine::initializeSimVectors()
{
    oaUInt4 numBits = params.initNumSimBits;

    simEngine.setLength(numBits);
    simEngine.randomizeVars(0, numBits - 1);
    {
        oaTimer timer;
        simEngine.simulateAll();
        stats.runtime.simulation += timer.getElapsed();
    }
}


void
SatSweepingEngine::updateActive()
{
    HasRep pred(*this);

    vector<Ref>::iterator newEnd = remove_if(activeNodes.begin(), 
                                             activeNodes.end(), pred);
    activeNodes.erase(newEnd, activeNodes.end());
}


// Ordering used to sort activeNodes.  Keys, in priority order:
//   1. simulation signature (simEngine.compareRef(x, y, true));
//   2. level;
//   3. reference value.
// Nodes with identical signatures therefore end up adjacent, shallowest first.
bool
SatSweepingEngine::LessNode::operator() (oagAi::Ref x,
                                         oagAi::Ref y) const
{
    const oaInt4 bySim = engine.simEngine.compareRef(x, y, true);
    if (bySim != 0) {
        return bySim < 0;
    }

    const oaUInt4 levelX = engine.levelizer.levelOf(x);
    const oaUInt4 levelY = engine.levelizer.levelOf(y);
    if (levelX != levelY) {
        return levelX < levelY;
    }

    return x < y;
}


// Sort activeNodes with LessNode and split them into clusters.  A cluster is
// a maximal run of adjacent nodes whose simulation signatures compare equal
// (compareRef(..., true) == 0).  Each Cluster holds the half-open index range
// [begin, end) into activeNodes.  With no active nodes, `clusters` stays empty.
void
SatSweepingEngine::computeClusters()
{
    // This timer is not restarted, so the computeClusters runtime also
    // includes the sort time that is charged to sortClusters.
    oaTimer timer;
    {
        LessNode predicate(*this);
        sort(activeNodes.begin(), activeNodes.end(), predicate);
    }
    stats.runtime.sortClusters += timer.getElapsed();

    clusters.clear();

    // Without this guard an empty activeNodes would produce a bogus empty
    // cluster [0, 0), which downstream code does not expect.
    if (!activeNodes.empty()) {
        Cluster curCluster(0, 0);

        for (oaUInt4 i = 1; i < activeNodes.size(); ++i) {
            Ref     x = activeNodes[i];
            Ref     y = activeNodes[curCluster.begin];
            oaInt4  cmp = simEngine.compareRef(x, y, true);

            // Sorted order makes equal signatures contiguous, so a mismatch
            // with the cluster's first node closes the current cluster.
            if (cmp != 0) {
                curCluster.end = i;
                clusters.push_back(curCluster);
                curCluster.begin = i;
            }
        }
        curCluster.end = activeNodes.size();
        clusters.push_back(curCluster);
    }

    stats.runtime.computeClusters += timer.getElapsed();
}


// remove_if predicate: a one-member cluster has no pair to check.
inline bool
SatSweepingEngine::IsSingleton::operator() (const Cluster &cluster) const
{
    const bool single = (1 == cluster.size());
    return single;
}

// Order clusters (each with at least two members) by the level of their
// second member.  That is the node refineClusters() pairs with the first
// member.
inline bool
SatSweepingEngine::LessCluster::operator() (const Cluster &c1,
                                            const Cluster &c2) const
{
    assert(c1.size() > 1);
    assert(c2.size() > 1);

    const Ref     secondOf1 = engine.activeNodes.at(c1.begin + 1);
    const Ref     secondOf2 = engine.activeNodes.at(c2.begin + 1);
    const oaUInt4 lhsLevel  = engine.levelizer.levelOf(secondOf1);
    const oaUInt4 rhsLevel  = engine.levelizer.levelOf(secondOf2);

    return lhsLevel < rhsLevel;
}


// Stop condition for run(): true once the per-check SAT backtrack budget
// exceeds params.maxCumSatBacktracks.
bool
SatSweepingEngine::reachedResourceLimit()
{
    const bool exhausted = params.maxCumSatBacktracks < resources.numSatBacktracks;
    return exhausted;
}


void
SatSweepingEngine::increaseResources()
{
        resources.numSatBacktracks = oaUInt4(resources.numSatBacktracks
                                         * BACKTRACKS_INCR_FACTOR);
}


void
SatSweepingEngine::refineClusters()
{
    // How many clusters should we refine before re-simulating?
    // How many nodes should we compare in each cluster?
    // We could use a more sophisticated scheme to make these decisions, but
    // for now we'll do the simple thing: Compare one pair in one cluster.
    {
        IsSingleton predicate;
        vector<Cluster>::iterator newEnd = remove_if(clusters.begin(), 
                                                     clusters.end(),
                                                     predicate);
        clusters.erase(newEnd, clusters.end());
    }

    {
        LessCluster predicate(*this);
        sort(clusters.begin(), clusters.end(), predicate);
    }
 
    //Cluster cluster = clusters.front();
    for (oaUInt4 i = 0; i < clusters.size(); i++) {
        Cluster cluster = clusters[i];
        assert(cluster.size() > 1);

        Ref u = activeNodes[cluster.begin];
        Ref v = activeNodes[cluster.begin + 1];

        // Adjust for phase.
        if (simEngine.getBit(u, 0)) {
            u = graph.notOf(u);
        }
        if (simEngine.getBit(v, 0)) {
            v = graph.notOf(v);
        }

        Map<Ref, bool>::Type assignment;
        if (verifyEquivalence(u, v, assignment)) {
            merge(u, v);
        } else if (assignment.size() > 0) {
            extendSimVectors(assignment);
        } 
    }
}


// SAT-check whether x and y are equivalent.
//
// Builds a fresh solver holding the fanin cones of x and y (shared nodes are
// encoded once), adds a constraint relating x and y, and solves with the
// current per-check backtrack budget.
//   - UNSAT: reported as equivalence (returns true).
//   - SAT: returns false and stores in `assignment` every registered input
//     whose solver value is l_True or l_False; other inputs are skipped.
//   - Undecided (budget exhausted): returns false.
// The outcome is counted in stats.sat, and the time in stats.runtime.satChecks.
bool
SatSweepingEngine::verifyEquivalence(oagAi::Ref                     x,
                                     oagAi::Ref                     y,
                                     Map<oagAi::Ref, bool>::Type    &assignment)
{
    oaTimer satTimer;

    AiSatSolver solver(graph); 

    Set<Ref>::Type visited;
    addTransitiveFaninToSolver(solver, x, visited);
    addTransitiveFaninToSolver(solver, y, visited);

    // The constraint must rule out x == y, since UNSAT means "equivalent".
    solver.addNotGate(solver.litOf(x), 
                      solver.litOf(y)); 

    lbool outcome = solver.solve(resources.numSatBacktracks);

    bool equivalent = false;
    if (outcome == l_False) {
        stats.sat.numUnsatisfiable++;
        equivalent = true;
    } else if (outcome == l_True) {
        stats.sat.numSatisfiable++;

        // Harvest the counterexample on the registered inputs.
        for (Set<Ref>::iterator pos = inputs.begin(); pos != inputs.end(); ++pos) {
            Ref input = graph.getNonInverted(*pos);
            lbool value = solver.getValue(input);
            if (value == l_True) {
                assignment[input] = true;
            } else if (value == l_False) {
                assignment[input] = false;
            }
        }
    } else {
        stats.sat.numUndecidable++;
    }

    stats.runtime.satChecks += satTimer.getElapsed();

    return equivalent;
}


// Recursively add the CNF for node's transitive fanin cone to satSolver.
//   AND         -> both fanin cones, then an AND gate whose output literal is
//                  the non-inverted node.
//   TERMINAL    -> the driver's cone plus a NOT-gate constraint between the
//                  driver and the complemented terminal literal; nothing for a
//                  null driver.
//   SEQUENTIAL, CONSTANT0 -> no clauses.
// Any other node type is an internal error.
// `alreadyAddedTransitiveFanin` holds the nodes already encoded into this
// solver, so shared logic is added once.
// NOTE(review): no clause here pins CONSTANT0 to false.  Confirm AiSatSolver
// does; otherwise equivalences with the constant can never be proven.
// NOTE(review): recursion depth follows the logic depth of the cone.
void
SatSweepingEngine::addTransitiveFaninToSolver(AiSatSolver           &satSolver,
                                              oagAi::Ref            node,
                                              Set<oagAi::Ref>::Type &alreadyAddedTransitiveFanin)
{
    // The gate literals below are always derived from getNonInverted(node),
    // so both phases of a node share the same clauses.  Key the visited set
    // on the non-inverted ref: keying on `node` itself re-encoded a node once
    // per phase in which it was reached.
    const Ref visitKey = graph.getNonInverted(node);
    if (alreadyAddedTransitiveFanin.find(visitKey) != alreadyAddedTransitiveFanin.end()) {
        return;
    }
    alreadyAddedTransitiveFanin.insert(visitKey);

    Ref left, right, driver;
    switch(graph.getNodeType(node)) {
      case Node::AND:
        left = graph.getAndLeft(node);
        right = graph.getAndRight(node);
        addTransitiveFaninToSolver(satSolver, left, alreadyAddedTransitiveFanin);
        addTransitiveFaninToSolver(satSolver, right, alreadyAddedTransitiveFanin);
        satSolver.addAndGate(satSolver.litOf(left),
                             satSolver.litOf(right),
                             satSolver.litOf(visitKey));
        break;
      case Node::TERMINAL:
        driver = graph.getTerminalDriver(node);
        if (!graph.isNull(driver)) {
            addTransitiveFaninToSolver(satSolver, driver, alreadyAddedTransitiveFanin);
            satSolver.addNotGate(satSolver.litOf(driver),
                                 satSolver.litOf(graph.notOf(visitKey)));
        }
        break;
      case Node::SEQUENTIAL:
      case Node::CONSTANT0:
        break;
      default:
        QUIT_ON_INTERNAL_ERROR;
    }
}


void
SatSweepingEngine::extendSimVectors(const Map<oagAi::Ref, bool>::Type &assignment)
{
    oaUInt4 i = simEngine.getLength();
    simEngine.setLength(i + 1);

    for (Set<Ref>::iterator it = inputs.begin(); it != inputs.end(); ++it) {
        Ref x = graph.getNonInverted(*it);

        Map<Ref, bool>::const_iterator it2 = assignment.find(x);
        if (it2 != assignment.end()) {
            bool bit = it2->second;
            simEngine.setBit(x, i, bit);
        } else if (x == graph.constantZero()) {
            simEngine.setBit(x, i, false);
        } else {
            bool bit = ((rand() & 1) == 0);
            simEngine.setBit(x, i, bit);
        }
    }
}

// Public entry point for merging x with replacement.  All the work, including
// the follow-on merges of structurally identical fanouts, is done by
// mergeRec().
void
SatSweepingEngine::merge(oagAi::Ref  x,
                         oagAi::Ref  replacement)
{
    mergeRec(x, replacement);
}

// Merge x and replacement, then merge any AND fanouts of the survivor that
// became structurally identical.
//
// Of the two, the node with the higher level is replaced by the other one via
// graph.resubstitute().  Levels are then updated from the survivor, and the
// merge is recorded in `reps`, keyed by the non-inverted replaced node.
// Nothing happens if the node to be replaced is not one of the user-specified
// `nodes`, or if either node has already been merged.
void
SatSweepingEngine::mergeRec(oagAi::Ref  x,
                            oagAi::Ref  replacement)
{
    // Only merge nodes that the user specified to be merged.
    if (nodes.find(graph.getNonInverted(x)) == nodes.end()) {
        return;
    }

    if (hasRep(x) || hasRep(replacement))
        return;

    // Keep the shallower node: make x the higher-level one.
    if (levelizer.levelOf(x) < levelizer.levelOf(replacement)) {
        Ref tmp = x;
        x = replacement;
        replacement = tmp;
    }

    // After a swap, x is the former replacement, which must also be one of
    // the user-specified nodes.
    if (nodes.find(graph.getNonInverted(x)) == nodes.end()) {
        return;
    }
    OAGSSW_LOGLN(logger, "merge " << x << " with " << replacement);
    graph.resubstitute(x, replacement);

    levelizer.updateLevel(replacement);
    stats.numMerges++;

    // reps is keyed by the non-inverted node.  When x is complemented,
    // "x == replacement" means "node(x) == NOT replacement", so store the
    // complemented representative to keep the recorded phase correct.
    const Ref xNode = graph.getNonInverted(x);
    reps[xNode] = (xNode == x) ? replacement : graph.notOf(replacement);

    typedef pair<Ref, Ref>      RefPair;
    typedef map<RefPair, Ref>   FanoutHash;

    // Structural hashing over replacement's unmerged AND fanouts.  Each input
    // pair is keyed in both orders so AND(a,b) and AND(b,a) collide.  For each
    // key, remember the lowest-level fanout; ties keep the first one seen.
    FanoutHash      fanoutHash;
    list<RefPair>   toMerge;
    const list<Ref> &fanout = graph.getFanout(replacement);
    for (list<Ref>::const_iterator it = fanout.begin();
            it != fanout.end();
            ++it) {
        Ref out = *it;

        if (graph.getNodeType(out) != Node::AND || hasRep(out)) {
            continue;
        }

        const RefPair keys[2] = {
            RefPair(graph.getAndLeft(out), graph.getAndRight(out)),
            RefPair(graph.getAndRight(out), graph.getAndLeft(out))
        };
        for (int k = 0; k < 2; ++k) {
            // find() instead of operator[]: never rely on a default-constructed
            // Ref being the null reference.
            FanoutHash::iterator hit = fanoutHash.find(keys[k]);
            if (hit == fanoutHash.end()) {
                fanoutHash.insert(FanoutHash::value_type(keys[k], out));
            } else if (levelizer.levelOf(hit->second) > levelizer.levelOf(out)) {
                hit->second = out;
            }
        }
    }

    // Every fanout that is not the chosen node for its input pair duplicates it.
    for (list<Ref>::const_iterator it = fanout.begin();
            it != fanout.end();
            ++it) {
        Ref out = *it;

        if (graph.getNodeType(out) != Node::AND || hasRep(out)) {
            continue;
        }

        FanoutHash::iterator hit =
            fanoutHash.find(RefPair(graph.getAndLeft(out), graph.getAndRight(out)));
        if (hit != fanoutHash.end() && hit->second != out) {
            toMerge.push_back(RefPair(out, hit->second));
        }
    }

    // Recurse only after both scans: the recursive merges rewire the graph,
    // and `fanout` refers to storage returned by graph.getFanout().
    for (list<RefPair>::iterator it = toMerge.begin();
         it != toMerge.end();
         ++it) {
        mergeRec(it->first, it->second);
    }
}

}
