#include "oaDesignDB.h"
#include "oagAiGraph.h"
#include "oagRedun.h"
#include "oagRedunCircuitSAT.h"
#include "oagRedunFlattener.h"
#include "oagRedunRTG.h"
#include <memory>

// DEBUG(x): debug-only statement hook. It expands to nothing, so any
// statement wrapped in DEBUG(...) is compiled out entirely.
#define DEBUG(x) 
// assertB(x) { body }: an assert that first runs `body` when x is false,
// so diagnostics can be printed before aborting.
//  - Debug builds: the for-loop runs `body` while !(x), then its increment
//    expression assert(x) fires.
//  - NDEBUG builds: expands to `if (0)`, so neither `body` nor x is evaluated.
#ifndef NDEBUG
#define assertB(x) for (;!(x); assert(x)) 
// {}
#else
#define assertB(x) if (0) 
// {}
#endif
// foreach(i, x): iterate `i` over container `x` (uses GCC's typeof extension).
// NOTE: `x` is evaluated once for begin() and again for every end() test,
// so pass a named container, not an expression that returns a temporary.
#define foreach(i,x) for(typeof((x).begin()) i=(x).begin(); i!=(x).end(); ++i)

using namespace std;
using namespace oa;
using namespace oagAi;

namespace {
oaNativeNS ns;
oaString   str;

template <class T>
const char *
getName(T* inst) {
    inst->getName(ns, str);
    return str;
}

template <>
const char *
getName(oaOccInstTerm* it) {
    oaString s1;
    it->getInst()->getName(ns, str);
    it->getTerm()->getName(ns, s1);
    str += ':';
    str += s1;
    return str;
}

template <>
const char *
getName(oaInstTerm* it) {
    oaString s1;
    it->getInst()->getName(ns, str);
    it->getTerm()->getName(ns, s1);
    str += ':';
    str += s1;
    return str;
}
}

namespace oagRedun {

struct Fault {
    oa::oaInstTerm *pin;
    oagAi::Ref         ref;
    bool               SA1;
    Fault::Fault(oa::oaInstTerm *p, oagAi::Ref r, bool b)
        : pin(p), ref(r), SA1(b) { }
};

/// Flattens `design` into `graph` through a Flattener (which fills `cis` and
/// `cos`) and appends candidate stuck-at faults for every single-bit input
/// pin of the top block's instances:
///  - terminal with more than one fanout: both stuck-at-0 and stuck-at-1;
///  - terminal feeding exactly one AND: stuck-at-1 if the AND uses it
///    non-inverted, stuck-at-0 if inverted (presumably the other polarity is
///    equivalent to a fault on the AND output);
///  - no fanout, or a single non-AND fanout: no fault.
/// Also calls graph.dot("flat") as a debug dump.
void
createProblemInstance(oaDesign *design, 
                      oagAi::Graph &graph, 
                      vector<oagAi::Ref> &cis, 
                      vector<oagAi::Ref> &cos, 
                      list<Fault> &faults)
{
    Flattener flat(design, graph, cis, cos);

    string dotName = "flat";
    graph.dot(dotName);

    oaIter<oaInst> instIter(design->getTopBlock()->getInsts());
    for (oaInst *inst = instIter.getNext(); inst; inst = instIter.getNext()) {
        oaIter<oaInstTerm> pinIter(inst->getInstTerms(oacTermIterSingleBit));
        for (oaInstTerm *pin = pinIter.getNext(); pin; pin = pinIter.getNext()) {
            if (pin->getTerm()->getTermType() != oacInputTermType)
                continue;

            oagAi::Ref site = flat.getTerm(pin);
            const list<oagAi::Ref> &fanout = graph.getFanout(site);
            assert(graph.getNodeType(site) == oagAi::Node::TERMINAL);

            if (fanout.size() > 1) {
                // Multiple fanouts: both polarities are candidates.
                faults.push_back(Fault(pin, site, 0));
                faults.push_back(Fault(pin, site, 1));
                continue;
            }
            if (fanout.size() != 1)
                continue;   // no fanout at all

            oagAi::Ref sink = fanout.front();
            if (graph.getNodeType(sink) != oagAi::Node::AND)
                continue;

            oagAi::Ref lhs = graph.getAndLeft (sink);
            oagAi::Ref rhs = graph.getAndRight(sink);
            bool stuckAt1 = true;
            if (oagAi::Graph::getNonInverted(lhs) == site) {
                assert(oagAi::Graph::getNonInverted(rhs) != site);
                stuckAt1 = !oagAi::Graph::isInverted(lhs);
            } else if (oagAi::Graph::getNonInverted(rhs) == site) {
                stuckAt1 = !oagAi::Graph::isInverted(rhs);
            } else {
                assert(0);
            }
            faults.push_back(Fault(pin, site, stuckAt1));
        }
    }
}

// Recursively walks the transitive fanin of `ref` (AND inputs and terminal
// drivers). Every driven TERMINAL node that is NOT in `sites` is bypassed:
// g.removeEquivalent(), g.resubstitute(node, driver) (presumably redirecting
// its uses to the driver), and its driver is then set to null. TERMINALs in
// `sites` are left intact. Visited marks use the current traversal ID; the
// wrapper overload calls g.newTraversalID() before starting the walk.
void
removeExtraTerminals(oagAi::Graph &g, oagAi::Ref ref, set<Ref> &sites)
{
    // Visit marks and node queries are done on the non-inverted ref.
    ref = g.getNonInverted(ref);

    if (g.isVisited(ref)) return;
    g.markVisited(ref);

    switch (g.getNodeType(ref)) {
        case oagAi::Node::CONSTANT0:
            break;
        case oagAi::Node::AND: 
            removeExtraTerminals(g, g.getAndLeft (ref), sites);
            removeExtraTerminals(g, g.getAndRight(ref), sites);
            break;
        case oagAi::Node::TERMINAL: {
            Ref ref0 = g.getTerminalDriver(ref);
            // Undriven terminal: nothing to bypass and nothing to recurse into.
            if (g.isNull(ref0)) return;
            // NOTE(review): `ref` is non-inverted here, but the wrapper inserts
            // the output refs into `sites` as-is; an inverted output ref would
            // not match and would be bypassed. Confirm cos never holds
            // inverted refs.
            if (sites.find(ref) == sites.end()) {
                g.removeEquivalent(ref);
                g.resubstitute(ref, ref0);
                g.setTerminalDriver(ref, g.getNull());
            }
            // Continue into the driver whether or not this terminal was kept.
            removeExtraTerminals(g, ref0, sites);
            break;
        }
        case oagAi::Node::NONE:
        case oagAi::Node::SEQUENTIAL:
            // Not expected in the flattened combinational graph.
            assert(0);
    }
}

/// Bypasses every driven terminal in the fanin of the outputs `cos` except
/// fault sites and the outputs themselves, then rehashes and garbage-collects
/// `g`.
void
removeExtraTerminals(oagAi::Graph &g, vector<oagAi::Ref> &cos, list<Fault> &faults)
{
    // Terminals that must survive the sweep.
    set<oagAi::Ref> keep;
    for (list<Fault>::const_iterator f = faults.begin(); f != faults.end(); ++f)
        keep.insert(f->ref);
    keep.insert(cos.begin(), cos.end());

    g.newTraversalID();
    for (vector<oagAi::Ref>::const_iterator co = cos.begin(); co != cos.end(); ++co)
        removeExtraTerminals(g, *co, keep);

    g.rehash();
    g.garbageCollect();
}

/// Random test generation phase. Each round runs one RTG random simulation
/// and then:
///  - erases faults whose node type is NONE (garbage collected);
///  - erases faults the simulation detected (rtg->isTested);
///  - removes the test site of each detected node that has no remaining
///    fault left on it.
/// Rounds repeat while a round detects more than 1% of the faults still
/// remaining after that round.
void
doRTG(oagAi::Graph &graph, vector<oagAi::Ref> &cos, list<Fault> &faults, unsigned seed)
{
    auto_ptr<RTG> rtg(RTG::createOAG(graph, cos, seed));

    for (;;) {
        rtg->randomSim();
        graph.newTraversalID();

        // Surviving faults mark their site visited so that a site still
        // carrying an undetected fault keeps its test site below.
        vector<oagAi::Ref> detected;
        list<Fault>::iterator fault = faults.begin();
        while (fault != faults.end()) {
            if (graph.getNodeType(fault->ref) == oagAi::Node::NONE) {
                fault = faults.erase(fault);   // garbage collected?
            } else if (rtg->isTested(fault->ref, fault->SA1)) {
                detected.push_back(fault->ref);
                fault = faults.erase(fault);
            } else {
                graph.markVisited(fault->ref);
                ++fault;
            }
        }

        for (vector<oagAi::Ref>::const_iterator d = detected.begin(); d != detected.end(); ++d) {
            oagAi::Ref site = *d;
            if (graph.isVisited(site))
                continue;   // still has a pending fault, or already removed
            assert(graph.getTerminalDriver(site) != graph.getNull());
            rtg->removeTestSite(site);
            graph.markVisited(site);
        }

        if (!(detected.size() > faults.size()/100.0))
            break;
    }
}

/// Wrapper because STL doesn't understand virtual functions
struct StuckAtRedunCBWrapper {
    StuckAtRedunCB *cb;
    StuckAtRedunCBWrapper(StuckAtRedunCB *c) : cb(c) { }
    bool operator()(Fault &a, Fault &b) {
        return cb->pinSort(a.pin, b.pin);
    }
};

/// Marks `ref` and its transitive fanin (AND inputs and terminal drivers) as
/// visited under the graph's current traversal ID; already-visited nodes
/// stop the walk. NOTE(review): the caller presumably starts a fresh
/// traversal with graph.newTraversalID() first, as the other walks here do.
void
markReachedLogic(oagAi::Graph &graph, oagAi::Ref ref)
{
    if (graph.isVisited(ref)) return;
    graph.markVisited(ref);

    switch (graph.getNodeType(ref)) {
        case oagAi::Node::NONE:
        case oagAi::Node::CONSTANT0:
            return;
        case oagAi::Node::AND: {
            oagAi::Ref l = graph.getAndLeft (ref);
            oagAi::Ref r = graph.getAndRight(ref);
            markReachedLogic(graph, l);
            markReachedLogic(graph, r);
            return;
        }
        case oagAi::Node::TERMINAL: {
            // Test for an undriven terminal with the graph's own null check
            // (as removeExtraTerminals does) instead of the Ref's truth value,
            // which only works if the null Ref happens to be 0.
            oagAi::Ref driver = graph.getTerminalDriver(ref);
            if (!graph.isNull(driver)) markReachedLogic(graph, driver);
            return;
        }
        case oagAi::Node::SEQUENTIAL:
            assert(0);
    }
}

/// Marks `ref` and every node in its transitive fanout as visited under the
/// graph's current traversal ID; already-visited nodes stop the walk.
void
markDcone(oagAi::Graph &g, oagAi::Ref ref)
{
    if (g.isVisited(ref)) return;
    g.markVisited(ref);

    // Fetch the fanout once. Re-evaluating g.getFanout(ref) for begin() and
    // for every end() test is wasteful, and would compare iterators from
    // different temporaries if getFanout ever returned by value.
    const list<oagAi::Ref> &fanout = g.getFanout(ref);
    for (list<oagAi::Ref>::const_iterator it = fanout.begin(); it != fanout.end(); ++it)
        markDcone(g, *it);
}

/// Removes from `cos` every output that is not in the transitive fanout of
/// any remaining fault site, calling graph.decrementExternalReferences() on
/// each dropped output. The order of the surviving outputs is not preserved.
void
removeExtraOutputs(oagAi::Graph &graph, vector<oagAi::Ref> &cos, list<Fault> &faults)
{
    graph.newTraversalID();
    for (list<Fault>::const_iterator f = faults.begin(); f != faults.end(); ++f)
        markDcone(graph, f->ref);

    // Index-based swap-and-pop: pop_back() invalidates an iterator to the
    // last element, so an iterator loop cannot safely compare it against
    // end() after removing that element.
    vector<oagAi::Ref>::size_type i = 0;
    while (i < cos.size()) {
        if (!graph.isVisited(cos[i])) {
            graph.decrementExternalReferences(cos[i]);
            cos[i] = cos.back();
            cos.pop_back();
            // Do not advance: slot i now holds an output not yet examined.
        } else {
            ++i;
        }
    }
}

/// Finds and removes stuck-at redundancies in `design`:
///  1. flattens the design into an AIG and enumerates candidate faults;
///  2. drops faults detected by random test generation (seeded by `seed`);
///  3. checks the remaining faults with CircuitSAT in the order given by
///     cb->pinSort(); each fault proven untestable (UNSAT) is reported through
///     cb->removeRedundancy(pin, SA1).
/// Prints timing and statistics to stdout.
/// @throws oagRedun::FloatingNetException, oagRedun::CycleException if the
///         design cannot be flattened.
void 
removeStuckAtRedundancies(oa::oaDesign *design, StuckAtRedunCB *cb, unsigned seed)
{
    // Build the AIG, its CI/CO lists and the candidate fault list.
    oagAi::Graph graph;
    vector<oagAi::Ref> cis, cos;
    list<Fault> faults;

    oaTimer timer;
    try {
        createProblemInstance(design, graph, cis, cos, faults);
    } catch (const Flattener::FloatingNetException &) {
        // Re-raise as the public oagRedun exception types.
        throw oagRedun::FloatingNetException();
    } catch (const Flattener::CycleException &) {
        throw oagRedun::CycleException();
    }

    removeExtraTerminals(graph, cos, faults);
    printf("Flattener runtime: %gs\n", timer.getElapsed());
    timer.reset();

    // list::size() is a size_t; %d for it is undefined on LP64 targets.
    printf("%lu faults before RTG\n", static_cast<unsigned long>(faults.size()));
    doRTG(graph, cos, faults, seed);
    printf("RTG runtime: %gs\n", timer.getElapsed());
    timer.reset();
    printf("%lu faults after RTG\n", static_cast<unsigned long>(faults.size()));

    // Drop outputs no remaining fault can reach, then compact the graph.
    removeExtraOutputs(graph, cos, faults);
    graph.rehash();
    graph.garbageCollect();

    // Process faults in the order requested by the callback.
    StuckAtRedunCBWrapper byPin(cb);
    faults.sort(byPin);

#ifdef OAGSAT
    auto_ptr<CircuitSAT> cs(CircuitSAT::createOAG(graph, cos));
#else
    auto_ptr<CircuitSAT> cs(CircuitSAT::createXYZ(graph, cos, true));
#endif
    cs->params.maxNumPropagations = 30000;

    unsigned redun = 0, giveups = 0, tested = 0;
    for (list<Fault>::const_iterator it = faults.begin(); it != faults.end(); ++it) {
        const Fault &f = *it;
        DEBUG(printf("checking %s\n", getName(f.pin)));
        CircuitSAT::Result res = cs->isSubstitutable(f.SA1 ? graph.constantOne() : graph.constantZero(), f.ref);
        if (res == CircuitSAT::UNSAT) {
            // No test exists for this fault: it is redundant.
            redun++;
            cb->removeRedundancy(f.pin, f.SA1);
            // Mirror the change into the SAT model only when the pin now sits
            // on a tie net of the stuck polarity. Guard against a pin the
            // callback left unconnected.
            oaNet *net = f.pin->getNet();
            if (net && net->getSigType() == (f.SA1 ? oacTieHiSigType : oacTieLoSigType)) {
                cs->substitute(f.SA1 ? graph.constantOne() : graph.constantZero(), f.ref);
            }
        } else if (res == CircuitSAT::TIMEOUT) {
            giveups++;
        } else if (res == CircuitSAT::SAT) {
            tested++;
        }
    }
    printf("SAT runtime: %gs\n", timer.getElapsed());
    printf("%u redundancies, %u undetermined, %u tested\n", redun, giveups, tested);

    // NOTE(review): unconditional debug dump of the final graph; consider
    // guarding it with DEBUG().
    string str ("after");
    graph.dot(str);
}

void
removeUnusedLogic(oaDesign *design)
{
    unsigned numRemoved = 0;
    bool changed = true;
    while (changed) {
        changed = false;

        oaBlock *block = design->getTopBlock();
        oaIter<oaInst> it(block->getInsts());
        oaInst *inst;
        while ((inst = it.getNext())) {
            bool unused = true;

            oaIter<oaInstTerm> it(inst->getInstTerms(oacInstTermIterAll));
            oaInstTerm *iterm;
            while ((iterm = it.getNext())) {
                assertB (iterm->getTerm()) {
                    printf("term %s\n", getName(iterm));
                }
                assert(iterm->getTerm());
                if (iterm->getTerm()->getTermType() == oacOutputTermType) {
                    oaIter<oaInstTerm> it2(iterm->getNet()->getInstTerms(oacInstTermIterAll));
                    oaInstTerm *iterm2;
                    while ((iterm2 = it2.getNext())) {
                        assert(iterm2->getTerm());
                        if (iterm2->getTerm()->getTermType() == oacInputTermType) {
                            unused = false;
                            goto nextInst;
                        }
                    }
                    oaIter<oaTerm> it3(iterm->getNet()->getTerms(oacTermIterAll));
                    oaTerm *iterm3;
                    while ((iterm3 = it3.getNext())) {
                        assert(iterm3);
                        if (iterm3->getTermType() == oacOutputTermType) {
                            unused = false;
                            goto nextInst;
                        }
                    }
                }
            }
            if (unused) {
                //printf("Instance %s is unused. REMOVED\n", getName(inst));
                inst->destroy();
                changed = true;
                numRemoved++;
            }
nextInst:;
        }
    }
    printf("Removed %d unused instances\n", numRemoved);
}
    
}
// vim:et:
