#include "oagRedunCircuitSAT.h"

// Debug tracing hook: expands to nothing, so every DEBUG(...) statement in
// this file is compiled out.  Redefine it as `x` to turn the traces on.
#define DEBUG(x) 
// assertB(cond) { diagnostics }: with assertions enabled, the attached block
// runs once when `cond` is false and then assert(cond) fires; under NDEBUG the
// block is never executed.
#ifndef NDEBUG
#define assertB(x) for (;!(x); assert(x)) 
// {}
#else
#define assertB(x) if (0) 
// {}
#endif
// Iterate over anything exposing begin()/end().  Relies on the GCC `typeof`
// extension; `x` is evaluated once for begin() and again for every end()
// comparison, so it should be cheap and free of side effects.
#define foreach(i,x) for(typeof((x).begin()) i=(x).begin(); i!=(x).end(); ++i)
using namespace std;

// Like assert(x), but `x` is evaluated even when NDEBUG is defined, so it is
// safe to wrap expressions that have side effects.
#define check(x) { bool qQqQq = x; assert(qQqQq); }
namespace {
/// Exception carrying a C-string message.  Only the pointer is stored, so the
/// text must outlive the exception object (string literals always do).
struct myexception : public std::exception {
    const char *s;

    myexception(const char *ss) { s = ss; }

    /// Returns the message handed to the constructor.
    virtual const char* what() const throw() { return s; }
};
}
// kludge to stringify line numbers
// EXTERNAL_ERROR("msg") throws myexception("msg at <file>:<line>").  The
// intermediate EXTERNAL_ERROR_ level makes __LINE__ expand to a number before
// EXTERNAL_ERROR__ applies the # operator.  The message must be a string
// literal, because it is joined to the location by literal concatenation.
#define EXTERNAL_ERROR__(x, l) throw myexception(x " at " __FILE__ ":" #l);
#define EXTERNAL_ERROR_(x, l)  EXTERNAL_ERROR__(x, l)
#define EXTERNAL_ERROR(x)      EXTERNAL_ERROR_(x, __LINE__)

/// Reference to a node of the internal Graph, optionally complemented.
/// Encoding: bits 2.. hold the node id and bit 0 the complement flag.  Bit 1
/// is always zero here; Fanout uses that position for its left/right flag.
class Ref {
    unsigned _data;
public:
    Ref() : _data(0) { }
    explicit Ref(unsigned i, bool s = false)
        : _data((i<<2) | s) { }

    unsigned id()   const { return _data >> 2; }
    bool     sign() const { return _data & 1;  }

    /// Complement the reference when `b` is true.
    Ref operator^(bool b) const { Ref l = *this; l._data^=b; return l;}
    Ref operator!() const { return *this ^ true; }
    bool operator==(const Ref l) const { return _data == l._data; }
    bool operator!=(const Ref l) const { return _data != l._data; }
    /// Print as "id:sign".  id() is unsigned, so it must use %u (not %d).
    void print() const {
        printf("%u:%d", id(), sign());
    }
};

/// One entry of a node's fanout list: the id of the consuming node, which of
/// its two inputs the edge enters (isRight: 0 = left, 1 = right), and whether
/// the edge is complemented.  Encoding: bits 2.. id, bit 1 isRight, bit 0 sign.
/// The all-zero default value is used as the list terminator.
class Fanout {
    unsigned _data;
public:
    Fanout() : _data(0) { }
    Fanout(unsigned i, bool c, bool s)
        : _data((i<<2) | (c<<1) | s) { }

    unsigned id()      const { return _data >> 2;       }
    bool     isRight() const { return (_data >> 1) & 1; }
    bool     sign()    const { return _data & 1;        }

    /// Complement the edge when `b` is true.
    Fanout operator^(bool b) const { Fanout l = *this; l._data^=b; return l;}
    bool operator==(const Fanout l) const { return _data == l._data; }
    bool operator!=(const Fanout l) const { return _data != l._data; }
    /// Print as "id:sign:L|R".  id() is unsigned, so it must use %u.
    void print() const {
        printf("%u:%d:%s", id(), sign(), isRight() ? "R" : "L");
    }
};

/// One vertex of the internal graph.  `inputs` holds up to two fanin
/// references (only inputs[0] is used by terminals).  `start` heads this
/// node's fanout list.  `next[k]` links to the next entry of the fanout list
/// of inputs[k].  `data` is a per-node scratch byte.
struct Node {
    Ref            inputs[2];
    unsigned char  type;
    unsigned char  data;
    unsigned short travId;
    Fanout         start;
    Fanout         next[2];

    enum Type { NONE, AND, TERMINAL, CONSTANT0 };

    /// Print "TYP (in0, in1) start (next0, next1)".  An unknown type prints no
    /// type tag.
    void print() const {
        static const char *const kTypeNames[] = { "NUL", "AND", "TRM", "ZRO" };
        if (type <= CONSTANT0)
            printf("%s", kTypeNames[type]);

        printf(" (");
        inputs[0].print();
        printf(", ");
        inputs[1].print();
        printf(") ");

        start.print();

        printf(" (");
        next[0].print();
        printf(", ");
        next[1].print();
        printf(")");
    }
};

class Graph {
    class FanoutProxy {
        const Graph    &graph;
        const unsigned  id;
        struct iterator {
            const Graph &graph;
            Fanout       current;

            void operator++() {
                current = graph._nodes[current.id()].next[current.isRight()];
            }
            Fanout operator*() {
                return current;
            }
            bool operator!=(iterator j) const { return current != j.current; }
            bool operator==(iterator j) const { return current == j.current; }
            iterator(const Graph &g, Fanout c) : graph(g), current(c) { }
        };
    public:
        FanoutProxy(const Graph &g, const unsigned i) : graph(g), id(i) { }
        iterator begin() { return iterator(graph, graph._nodes[id].start); }
        iterator end()   { return iterator(graph, graph.getNullFanout());  }
    };

    vector<Node> _nodes;    
public:
    static Fanout getNullFanout() { return Fanout(); }
    static Ref    getNull()       { return  Ref();   }
    static Ref    constantZero()  { return  Ref(1);  }
    static Ref    constantOne()   { return !Ref(1);  }

    Node::Type getNodeType(unsigned id) const { return static_cast<Node::Type>(_nodes[id].type); }
    Ref        getAndLeft (unsigned id) const { return _nodes[id].inputs[0]; }
    Ref        getAndRight(unsigned id) const { return _nodes[id].inputs[1]; }
    Ref        getAndInput(unsigned id, unsigned lr) const { return _nodes[id].inputs[lr]; }
    Ref        getTerminalDriver(unsigned id) const { return _nodes[id].inputs[0]; }
    FanoutProxy getFanout(unsigned id) const { return FanoutProxy(*this, id); }

    char       getData(unsigned id) const   { return _nodes[id].data; }
    void       setData(unsigned id, char d) { _nodes[id].data = d;    }
    unsigned char&      data(unsigned id)            { return _nodes[id].data; }

    // to make results match with OAGear AIGs
    void       reorderFanouts(unsigned id, const vector<Fanout> &fo);
    void       reorderFanouts(unsigned id, const vector<unsigned> &fo);

    void       setTerminalDriver(unsigned id, Ref r);
    Ref        newTerminal(Ref r);
    Ref        newAnd(Ref l, Ref r);

    Graph();
    void print() const;
    void printBLIF(FILE *fp=stdout) const;
};

/// Dump every node, one per line, prefixed with its internal id.
void
Graph::print() const
{
    const unsigned count = _nodes.size();
    for (unsigned idx = 0; idx != count; ++idx) {
        printf("%d ", idx);
        _nodes[idx].print();
        printf("\n");
    }
}

void
Graph::printBLIF(FILE *fp) const
{
    fprintf(fp, ".model foo\n");
    fprintf(fp, ".inputs");
    for (unsigned i=2; i<_nodes.size(); i++) {
        if (getNodeType(i)==Node::TERMINAL and getTerminalDriver(i)==getNull())
            fprintf(fp, " n%d", i);
    }
    fprintf(fp, "\n");
    fprintf(fp, ".outputs");
    for (unsigned i=2; i<_nodes.size(); i++) {
        // XXX: big hack
        if (_nodes[i].data & 0x20)
            fprintf(fp, " n%d", i);
    }
    fprintf(fp, "\n");
    // constant 0
    fprintf(fp, ".names n1\n0\n");
    // everything else
    for (unsigned i=2; i<_nodes.size(); i++) {
        switch (getNodeType(i)) {
            case Node::TERMINAL: {
                Ref l = getTerminalDriver(i);
                if (l == getNull()) continue;
                fprintf(fp, ".names n%d n%d\n", l.id(), i);
                fprintf(fp, "%d 1\n", !l.sign());
                break;
            }
            case Node::AND: {
                Ref l = getAndLeft (i);
                Ref r = getAndRight(i);
                fprintf(fp, ".names n%d n%d n%d\n", l.id(), r.id(), i);
                fprintf(fp, "%d%d 1\n", !l.sign(), !r.sign());
                break;
            }
            default: abort();
        }
    }
}

void
Graph::reorderFanouts(unsigned id, const vector<Fanout> &fo)
{
    Node &n = _nodes[id];
    if (fo.empty()) {
        assert(n.start == Fanout());
    } else {
        n.start = fo[0];
        for (unsigned i=1; i<fo.size(); i++) {
            _nodes[fo[i-1].id()].next[fo[i-1].isRight()] = fo[i];
        }
        _nodes[fo.back().id()].next[fo.back().isRight()] = Fanout();
    }
}

void
Graph::reorderFanouts(unsigned id, const vector<unsigned> &foIds)
{
    Node &n = _nodes[id];

    vector<Fanout> fo;
    for (unsigned i=0; i<foIds.size(); i++) {
        const unsigned foi = foIds[i];
        if (_nodes[foi].inputs[0].id() == id) {
            fo.push_back(Fanout(foi, 0, _nodes[foi].inputs[0].sign()));
        } else {
            assert(_nodes[foi].inputs[1].id() == id);
            fo.push_back(Fanout(foi, 1, _nodes[foi].inputs[1].sign()));
        }
    }
    reorderFanouts(id, fo);
}

/// Re-point terminal `id` at driver `l`.
/// First unlinks the terminal from the fanout list of its current driver (if
/// any), then pushes it onto the head of `l`'s fanout list.  A null `l` leaves
/// the terminal undriven.
void
Graph::setTerminalDriver(unsigned id, Ref l)
{
    //print();
    //printf("settermdrive %d ", id); l.print(); printf("\n");
    Node &n = _nodes[id];
    // remove from old fanout
    if (n.inputs[0] != getNull()) {
        Fanout f = _nodes[n.inputs[0].id()].start;
        if (f == Fanout()) {
            // Old driver has an empty fanout list although this terminal uses
            // it: the lists are inconsistent.  On this path inputs[0] is not
            // getNull() (== Ref()), so this assert always fails.
            assert(n.inputs[0] == Ref());
        } else if (f.id() == id) {
            // terminal is the list head; a terminal only uses input slot 0,
            // so matching the id alone is enough
            _nodes[n.inputs[0].id()].start = n.next[0];
        } else {
            // walk the list to the entry preceding this terminal and splice it out
            // NOTE(review): if the terminal is missing from the list, this only
            // stops via the assert; with NDEBUG it keeps walking from node 0.
            while (1) {
                Fanout next = _nodes[f.id()].next[f.isRight()];
                if (next.id() == id) {
                    _nodes[f.id()].next[f.isRight()] = n.next[0];
                    break;
                }
                f = next;
                assert(f != Fanout());
            }
        }
    }

    // link into the new driver's fanout list (as its new head)
    // NOTE(review): for a null `l` this copies node 0's list head, presumably
    // always empty -- confirm
    n.inputs[0] = l;
    n.next  [0] = _nodes[l.id()].start;
    if (l != getNull()) {
        _nodes[l.id()].start = Fanout(id, 0, l.sign());
    }
    //print();
}

/// Append a new terminal driven by `l` (null for a primary input) and, when
/// `l` is not null, push it onto the head of `l`'s fanout list.
Ref
Graph::newTerminal(Ref l)
{
    const unsigned id = _nodes.size();
    _nodes.push_back(Node());

    _nodes[id].type      = Node::TERMINAL;
    _nodes[id].inputs[0] = l;
    _nodes[id].next[0]   = _nodes[l.id()].start;
    if (l != getNull())
        _nodes[l.id()].start = Fanout(id, 0, l.sign());

    return Ref(id);
}

/// Append a new AND node with inputs `l` (left, slot 0) and `r` (right,
/// slot 1).  The node is pushed onto the head of each non-null input's fanout
/// list.  The left input is linked completely before the right one, so
/// AND(x, x) gets both entries.
Ref
Graph::newAnd(Ref l, Ref r)
{
    const unsigned id = _nodes.size();
    _nodes.push_back(Node());
    _nodes[id].type = Node::AND;

    const Ref ins[2] = { l, r };
    for (unsigned slot = 0; slot < 2; ++slot) {
        const Ref in = ins[slot];
        _nodes[id].inputs[slot] = in;
        _nodes[id].next[slot]   = _nodes[in.id()].start;
        if (in != getNull())
            _nodes[in.id()].start = Fanout(id, slot, in.sign());
    }

    return Ref(id);
}

/// Create a graph holding only the two reserved nodes: the null node (id 0)
/// and the constant-zero node (id 1).
Graph::Graph()
    : _nodes(2)
{
    assert(getNull().id()      == 0);
    assert(constantZero().id() == 1);
    _nodes[getNull().id()].type      = Node::NONE;
    _nodes[constantZero().id()].type = Node::CONSTANT0;
}

/// Literals: a (node id, circuit copy, polarity) triple packed into one word.
/// Bits 2.. hold the node id, bit 1 selects the circuit copy (ckt), and bit 0
/// is the complement flag.  var() drops only the sign, so it names a
/// (node, circuit copy) pair.
class Lit {
public:
    unsigned _data;
public:
    Lit(unsigned i, bool c, bool s = false)
        : _data((i<<2) | (c<<1) | s) { }

    unsigned id ()  const { return _data >> 2;       }
    bool     ckt()  const { return (_data >> 1) & 1; }
    unsigned var()  const { return _data >> 1;       }
    bool     sign() const { return _data & 1;        }

    bool operator==(const Lit l) const { return _data == l._data; }
    bool operator!=(const Lit l) const { return _data != l._data; }
    /// Same literal with the complement flag cleared.
    Lit  abs() const { Lit l = *this; l._data&=~1; return l;}
    /// Complement the literal when `b` is true.
    Lit  operator^(bool b) const { Lit l = *this; l._data^=b; return l;}
    Lit  operator!() const { return *this ^ true; }
    /// Print as "id:ckt:value"; a complemented literal prints 0, otherwise 1.
    /// id() is unsigned, so it is printed with %u.
    void print(FILE *fp=stdout) const {
#ifdef COLOR
        if (sign()) {
            fprintf(fp, "%u:%d:\033[31m0\033[0m", id(), ckt());
        } else {
            fprintf(fp, "%u:%d:\033[32m1\033[0m", id(), ckt());
        }
#else
        if (sign()) {
            fprintf(fp, "%u:%d:0", id(), ckt());
        } else {
            fprintf(fp, "%u:%d:1", id(), ckt());
        }
#endif
    }

};
/// Sentinel literals.  UINT_MAX << 2 drops the top two bits, so both report
/// id() == UINT_MAX >> 2; they differ only in the ckt bit.
const Lit lit_Undef(UINT_MAX, 0);
const Lit lit_Error(UINT_MAX, 1);


/// Three-valued logic: 0 = false, 1 = true, 3 = unknown ("X").  The numeric
/// encoding is visible through toInt()/operator int() and is packed into
/// two-bit fields elsewhere in this file, so it must not change.
class lbool {
    char value;
    explicit lbool(char v) : value(v) { }
public:
    lbool()       : value(3) { }
    lbool(bool x) : value(x) { }

    int toInt(void) const { return value; }
    operator int() const  { return value; }

    bool operator==(lbool other) const { return value == other.value; }
    bool operator!=(lbool other) const { return !(value == other.value); }

    /// Negate when `b` is true.  Bit 1 (set only for X) is OR-ed back into
    /// bit 0, so X stays X either way.
    lbool operator^ (bool b) const {
        lbool out;
        const int flipped = value ^ b;
        const int xbit    = value >> 1;
        out.value = flipped | xbit;
        return out;
    }

    /// Print as 0, 1 or X (colored when COLOR is defined).
    void print(FILE *fp=stdout) const {
#ifdef COLOR
        const char *text = (value == 0) ? "\033[31m0\033[0m"
                         : (value == 1) ? "\033[32m1\033[0m"
                         :                "\033[33mX\033[0m";
#else
        const char *text = (value == 0) ? "0"
                         : (value == 1) ? "1"
                         :                "X";
#endif
        fputs(text, fp);
    }

    friend char  toInt  (lbool l);
    friend lbool toLbool(char v); 
};
inline char  toInt  (lbool l) { return l.toInt(); }
inline lbool toLbool(char v)  { return lbool(v);  }

const lbool l_True  = toLbool(0x1);
const lbool l_False = toLbool(0x0);
const lbool l_Undef = toLbool(0x3);

namespace oagRedun {

/// The solver
/// Circuit-based SAT engine that mirrors an OAGear AIG into the internal
/// Graph.  It keeps two value copies per node (circuit copy 0 and 1); inside
/// the fault cone the two copies are tracked separately.
class CircuitSATxyz : public CircuitSAT {
    /// AI graph
    oagAi::Graph      &_oagGraph;
    // dirty marker of _oagGraph as last seen; checked (asserted) on entry to the public API
    unsigned           _dirtyMarker;
    // internal graph
    Graph              _graph;
    // maps node IDs between two graphs
    // (UINT_MAX marks an id with no counterpart)
    vector<unsigned>   _int2ext;
    vector<unsigned>   _ext2int;
    void createGraphRec(oagAi::Ref);
    void syncFanouts(unsigned id);
    void setTerminalDriver(oagAi::Ref, oagAi::Ref);
    // Ref/Lit translation between the two graphs: the node id is remapped,
    // while polarity (and circuit copy for Lit) is kept.
    Ref  oag2xyz(oagAi::Ref r) const { 
        //assert(oagAi::Graph::getNodeID(r) < _ext2int.size());
        //assert(_ext2int[oagAi::Graph::getNodeID(r)] != UINT_MAX);
        return Ref(_ext2int[oagAi::Graph::getNodeID(r)], oagAi::Graph::isInverted(r)); }
    oagAi::Ref xyz2oag(Ref r) const { 
        //assert(r.id() < _int2ext.size());
        //assert(_int2ext[r.id()] != UINT_MAX);
        return oagAi::Graph::getRefFromID(_int2ext[r.id()], r.sign()); }
    Lit oag2xyz(Lit r) const { 
        DEBUG(if (r.id() == lit_Undef.id()) return r);
        return Lit(_ext2int[r.id()], r.ckt(), r.sign()); }
    Lit xyz2oag(Lit r) const { 
        DEBUG(if (r.id() == lit_Undef.id()) return r);
        return Lit(_int2ext[r.id()], r.ckt(), r.sign()); }
    // dominators
    // _inDFSorder: true while every fanout has a larger internal id than its
    // fanin; dominators are only computed then.
    bool                    _inDFSorder;
    vector<unsigned>        _dominators;
    void computeDominatorsDC();

    static Ref id2ref(unsigned id) { return Ref(id); }
    static unsigned ref2id(Ref rf) { return rf.id(); }

    static Ref lit2ref(Lit l)         { return Ref(l.id(), l.sign()); }
    static Lit ref2lit(Ref r, bool c) { return Lit(r.id(), c, r.sign());}

    // Per-node data byte layout (see the accessors below):
    //   bits 0-1 value in circuit copy 0, bits 2-3 value in copy 1,
    //   0x10 in fault cone, 0x20 circuit output, 0x40 visited, 0x80 D-dominator.
    // NOTE(review): Graph::getData returns char, so with a signed char the
    // value is sign-extended when bit 7 is set; every user masks or truncates.
    unsigned getData(unsigned id) const       { return _graph.getData(id); }
    void     setData(unsigned id, unsigned x) { _graph.setData(id, x);     }

    void setVisited(unsigned id, bool);
    bool isVisited (unsigned id);

    void setCO(unsigned id, bool);
    bool isCO (unsigned id);

    /// fault cone
    vector<unsigned>        _faultCone;
    void setDcone (unsigned id, bool);
    bool isInDcone(unsigned id);
    void markDcone(unsigned id, bool);
    bool isDDominator(unsigned id);
    void setDDominator(unsigned id, bool);

    /// SAT stuff
    lbool getValue(Lit);
    void  setValue(unsigned id, bool ckt, lbool);
    bool  enqueue (Lit);
    void  printValue(unsigned id);

    // implication queue and justification queue
    unsigned         _qhead;
    vector<Lit>      _trail;
    vector<Lit>      _jFront;
    bool             _reachedPO;
    // thrown when the propagation budget (params.maxNumPropagations) is exceeded
    struct TimeOut : public exception { };

    // declared only: its definition is disabled with #if 0 in this file
    bool implyI(Lit);
    bool implyI(Fanout fo, bool ckt);
    bool implyO(Lit);

    Lit  decide();
    bool propagate();
    bool solve(Lit p);

    struct {
        unsigned numPropagations;
    } stats;

    oagAi::Ref _faultSite;
    Result isSubstitutableH(oagAi::Ref src, oagAi::Ref old, oagAi::Ref tgt);

    void printGraphs() const;
public:
    CircuitSATxyz(oagAi::Graph &g, const vector<oagAi::Ref> &o, bool doms);

    virtual Result solve(const std::vector<oagAi::Ref> &assumps);

    virtual Result isSubstitutable(oagAi::Ref src, oagAi::Ref tgt);
    virtual void   substitute     (oagAi::Ref src, oagAi::Ref tgt);
};

void
CircuitSATxyz::printGraphs() const
{
    _oagGraph.print();
    for (unsigned i=0; i<_int2ext.size(); i++) {
        printf("%d %d\n", i, _int2ext[i]);
    }
    _graph.print();
}

/// Graph dominator algorithm
/// (as described in Cooper et al., "A Simple, Fast Dominance Algorithm")
///
/// For every node of the current fault cone, compute its immediate dominator
/// towards the outputs (over fanout edges) and store it in _dominators.
/// UINT_MAX stands for a virtual sink beyond all outputs.
/// Precondition (callers check _inDFSorder): every fanout has a larger
/// internal id than its fanin.  Visiting nodes in decreasing id order then
/// guarantees that all fanouts of a node are processed before the node.
/// Side effect: _faultCone is sorted in place (descending).
void
CircuitSATxyz::computeDominatorsDC() 
{
    assert(_dominators.size() == _int2ext.size());
    // sort nodes in reverse DFS
    std::sort(_faultCone.begin(), _faultCone.end(), greater<unsigned>());

    foreach (it, _faultCone) {
        unsigned i      = *it;
        // nodes without fanout keep UINT_MAX (the virtual sink) as dominator
        _dominators[i] = UINT_MAX;

        // `newdom == i` means "no fanout seen yet"
        unsigned newdom = i;
        foreach (it, _graph.getFanout(i)) {
            if (newdom == i) {
                // first fanout
                newdom = (*it).id();
            } else {
                // all others, do intersect
                // Dominators have larger ids, so the finger with the smaller
                // id climbs its dominator chain until the two meet.
                unsigned finger1 = (*it).id();
                unsigned finger2 = newdom;
                while (finger1 != finger2) {
                    while (finger1 < finger2) {
                        assert(finger1 != UINT_MAX);
                        finger1 = _dominators[finger1];
                    }
                    while (finger1 > finger2) {
                        assert(finger2 != UINT_MAX);
                        finger2 = _dominators[finger2];
                    }
                }
                newdom = finger1;
            }
            _dominators[i] = newdom;
        }
    }
}

/// Recursively mirror OAGear node `ref` and its transitive fanin into the
/// internal graph.  Inputs are created before the node itself, so a node's
/// inputs always get smaller internal ids.  The id mapping is recorded in both
/// directions, and the new node's flag and value bits are cleared.
/// OAGear visited marks (for the current traversal id) prevent revisits.
/// The OAGear null / constant nodes map to internal nodes 0 / 1; sequential
/// nodes abort.
/// NOTE(review): recursion depth equals the logic depth of the fanin cone.
void
CircuitSATxyz::createGraphRec(oagAi::Ref ref)
{
    if (_oagGraph.isVisited(ref)) return;
    _oagGraph.markVisited(ref);

    // stays the null reference for node types not handled below
    Ref intO;
    switch (_oagGraph.getNodeType(ref)) {
        case oagAi::Node::AND: {
            oagAi::Ref oagL = _oagGraph.getAndLeft (ref);
            oagAi::Ref oagR = _oagGraph.getAndRight(ref);
            createGraphRec(oagL);
            createGraphRec(oagR);

            Ref intL = oag2xyz(oagL);
            Ref intR = oag2xyz(oagR);
            intO = _graph.newAnd(intL, intR);
            break;
        } 
        case oagAi::Node::TERMINAL: {
            oagAi::Ref oagL = _oagGraph.getTerminalDriver(ref);
            assert(_oagGraph.getNodeID(oagL) != _oagGraph.getNodeID(ref));
            createGraphRec(oagL);

            Ref intL = oag2xyz(oagL);
            intO = _graph.newTerminal(intL);
            break;
        }
        case oagAi::Node::NONE:
            intO = _graph.getNull();
            break;
        case oagAi::Node::CONSTANT0:
            intO = _graph.constantZero();
            break;
        case oagAi::Node::SEQUENTIAL:
            abort();
    }
    unsigned eid = _oagGraph.getNodeID(ref);
    unsigned iid = intO.id();

    // grow both maps on demand; unfilled slots hold UINT_MAX
    if (_ext2int.size() < eid+1) 
        _ext2int.resize(eid+1, UINT_MAX);
    _ext2int[eid] = iid;
    if (_int2ext.size() < iid+1) 
        _int2ext.resize(iid+1, UINT_MAX);
    _int2ext[iid] = eid;

    // reset all per-node flags; Dcone is cleared first, so setValue writes
    // X into both circuit copies
    setVisited(iid, false);
    setCO   (iid, false);
    setDcone(iid, false);
    setDDominator(iid, false);
    setValue(iid, 0, l_Undef);
}

/// Rebuild the internal fanout list of internal node `id` so that it follows
/// the order of the corresponding OAGear node's fanout list.
///
/// OAGear fanouts with no internal counterpart are skipped.  These are nodes
/// outside the transitive fanin of the outputs given to the constructor; their
/// ids either lie beyond _ext2int or map to UINT_MAX.  Indexing with them
/// would read past _ext2int or pass UINT_MAX on as a node id.
void
CircuitSATxyz::syncFanouts(unsigned id)
{
    const oagAi::Ref ref = xyz2oag(Ref(id));

    vector<unsigned> newFos;
    foreach (it, _oagGraph.getFanout(ref)) {
        const unsigned foId = _oagGraph.getNodeID(*it);
        if (foId >= _ext2int.size() || _ext2int[foId] == UINT_MAX)
            continue;   // not mirrored in the internal graph
        newFos.push_back(_ext2int[foId]);
    }
    _graph.reorderFanouts(id, newFos);
}

/// Re-drive terminal `tgt` by `src` in both graphs, then resynchronise the
/// fanout order of the old and the new driver with OAGear.
void
CircuitSATxyz::setTerminalDriver(oagAi::Ref tgt, oagAi::Ref src)
{
    const Ref terminal  = oag2xyz(tgt);
    const Ref newDriver = oag2xyz(src);
    const Ref oldDriver = _graph.getTerminalDriver(terminal.id());

    // a driver with a larger id than its terminal breaks the topological id
    // order that the dominator computation relies on
    _inDFSorder = _inDFSorder && newDriver.id() <= terminal.id();

    _graph.setTerminalDriver(terminal.id(), newDriver);
    _oagGraph.setTerminalDriver(tgt, src);

    syncFanouts(oldDriver.id());
    syncFanouts(newDriver.id());
}

/// Mirror the transitive fanin of the outputs `o` of `g` (plus the null and
/// constant nodes) into the internal graph, and flag those outputs as COs.
/// `doms` enables dominator-based D-implications; it is cleared later if a
/// substitution breaks the topological id order.
/// Search-state members that individual queries normally (re)initialise are
/// also given defined starting values here, so none is ever read
/// uninitialised.
CircuitSATxyz::CircuitSATxyz(oagAi::Graph &g, const vector<oagAi::Ref> &o, bool doms)
    : _oagGraph(g), _inDFSorder(doms), _qhead(0), _reachedPO(false)
{
    _dirtyMarker = _oagGraph.getDirtyMarker();
    stats.numPropagations = 0;

    // create internal version of graph
    g.newTraversalID();
    createGraphRec(_oagGraph.getNull());
    createGraphRec(_oagGraph.constantOne());
    foreach (it, o) {
        createGraphRec(*it);
    }
    foreach (it, o) {
        setCO(oag2xyz(*it).id(), true);
    }
    assert(xyz2oag(_graph.constantOne()) == _oagGraph.constantOne());
    assert(oag2xyz(_oagGraph.constantOne()) == _graph.constantOne());
    assert(xyz2oag(_graph.getNull()) == _oagGraph.getNull());
    assert(oag2xyz(_oagGraph.getNull()) == _graph.getNull());

    // fanout lists were built in creation order; reorder them to match OAGear
    for (unsigned i=0; i<_int2ext.size(); i++) {
        assert(_int2ext[i] != UINT_MAX);
        syncFanouts(i);
    }
    _dominators.resize(_int2ext.size());
    DEBUG(printGraphs());
}

/// D-dominator flag: bit 7 (0x80) of the node's data byte.
void 
CircuitSATxyz::setDDominator(unsigned id, bool b)
{
    const unsigned mask = 0x80;
    const unsigned cur  = getData(id);
    setData(id, b ? (cur | mask) : (cur & ~mask));
}

bool 
CircuitSATxyz::isDDominator(unsigned id)
{
    return (getData(id) & 0x80) != 0;
}

/// Visited flag used by decide()'s walk: bit 6 (0x40) of the data byte.
void 
CircuitSATxyz::setVisited(unsigned id, bool b)
{
    const unsigned mask = 0x40;
    const unsigned cur  = getData(id);
    setData(id, b ? (cur | mask) : (cur & ~mask));
}

bool 
CircuitSATxyz::isVisited(unsigned id)
{
    return (getData(id) & 0x40) != 0;
}

/// Circuit-output flag: bit 5 (0x20) of the data byte.
void 
CircuitSATxyz::setCO(unsigned id, bool b)
{
    const unsigned mask = 0x20;
    const unsigned cur  = getData(id);
    setData(id, b ? (cur | mask) : (cur & ~mask));
}

bool 
CircuitSATxyz::isCO(unsigned id)
{
    return (getData(id) & 0x20) != 0;
}

/// Fault-cone membership flag: bit 4 (0x10) of the data byte.
void 
CircuitSATxyz::setDcone(unsigned id, bool b)
{
    const unsigned mask = 0x10;
    const unsigned cur  = getData(id);
    setData(id, b ? (cur | mask) : (cur & ~mask));
}

bool 
CircuitSATxyz::isInDcone(unsigned id)
{
    return (getData(id) & 0x10) != 0;
}

/// b == true: breadth-first collect the transitive fanout of `id` (the fault
/// site) into _faultCone and set their Dcone flags.  Throws TimeOut once more
/// than params.maxNumPropagations nodes have been dequeued; the partial cone
/// stays recorded so the b == false call can clear it.  When _inDFSorder
/// holds, dominators are then computed, and `id` plus every node on its
/// dominator chain are flagged as D-dominators.
/// b == false: clear the Dcone and D-dominator flags of every recorded node
/// and empty _faultCone.
void 
CircuitSATxyz::markDcone(unsigned id, bool b)
{
    if (b) {
        assert(_faultCone.empty());
        unsigned fhead = 0;
        setDcone(id, true);
        _faultCone.push_back(id);
        while (fhead < _faultCone.size()) {
            // shadows the parameter only inside this loop
            unsigned id = _faultCone[fhead++];
            foreach (it, _graph.getFanout(id)) {
                unsigned foid = (*it).id();
                // already marked?
                if (isInDcone(foid)) continue;
                setDcone(foid, true);
                _faultCone.push_back(foid);
            }
            if (fhead > params.maxNumPropagations) throw TimeOut();
        }
        if (_inDFSorder) {
            computeDominatorsDC();
            // `id` here is the parameter (the fault site)
            unsigned d = id;
            while (d!=UINT_MAX) {
                setDDominator(d, b);
                d = _dominators[d];
            }
        }
    } else {
        foreach (it, _faultCone) {
            setDcone(*it, false);
            setDDominator(*it, false);
        }
        _faultCone.clear();
    }
}

/// Current value of literal `l`: read the two-bit field of its circuit copy
/// (bits 0-1 for copy 0, bits 2-3 for copy 1) and apply the literal's sign.
lbool
CircuitSATxyz::getValue(Lit l)
{
    const unsigned field = getData(l.id()) >> (l.ckt() ? 2 : 0);
    const char     raw   = field & 0x3;
    return toLbool(raw) ^ l.sign();
}

void
CircuitSATxyz::printValue(unsigned id) 
{
    lbool v0 = getValue(Lit(id, 0));
    lbool v1 = getValue(Lit(id, 1));
    printf("%d:", _int2ext[id]);
    v0.print();
    printf("/");
    v1.print();
}

/// Store `val` for node `id`.  Inside the fault cone only the field of circuit
/// copy `ckt` is written.  Outside the cone both copies share one value, so
/// both two-bit fields are written and `ckt` is ignored.
void
CircuitSATxyz::setValue(unsigned id, bool ckt, lbool val)
{
    //printf("setval %d %d %d\n", id, ckt, int(val));
    const int v    = toInt(val);
    unsigned  bits = getData(id);
    if (isInDcone(id)) {
        const unsigned shift = ckt ? 2 : 0;
        bits = (bits & ~(0x3u << shift)) | (v << shift);
    } else {
        bits = (bits & ~0xFu) | (v << 2) | v;
    }
    setData(id, bits);
}

/// Make literal `l` true.  Returns true if it already was true, false (a
/// conflict) if it was false; otherwise assigns it and appends it to the trail.
/// For a D-dominator node, the complementary value is then also enqueued in the
/// other circuit copy, and that result is returned.
bool
CircuitSATxyz::enqueue(Lit l)
{
    lbool cur = getValue(l);
    DEBUG(printf("enqueue "); if (isInDcone(l.id())) printf("DC"); xyz2oag(l).print(); printf(" (%d)\n", int(cur)));
    if (cur == l_True || cur == l_False)
        return cur == l_True;

    setValue(l.id(), l.ckt(), !l.sign());
    _trail.push_back(l);

    return isDDominator(l.id())
         ? enqueue(Lit(l.id(), !l.ckt(), !l.sign()))
         : true;
}

bool
CircuitSATxyz::propagate()
{
    while (_qhead < _trail.size()) {
        Lit p = _trail[_qhead++];
        
        if (!implyO(p)) return false;

        if (!isInDcone(p.id())) {
            foreach (it, _graph.getFanout(p.id())) {
                Fanout fo = *it;
                DEBUG(printf("%s%d -> %d\n", fo.sign()?"~":"", _int2ext[p.id()], _int2ext[fo.id()]));
                if (!implyI(fo ^ p.sign(), p.ckt()))
                    return false;
                if (isInDcone(fo.id())) {
                    //if (!implyI(ref2lit(ref, !p.ckt()) ^ p.sign()))
                    if (!implyI(fo ^ p.sign(), !p.ckt()))
                        return false;
                }
            }
        } else {
            foreach (it, _graph.getFanout(p.id())) {
                Fanout fo = *it;
                DEBUG(printf("%s%d -> %d\n", fo.sign()?"~":"", _int2ext[p.id()], _int2ext[fo.id()]));
                if (!implyI(fo ^ p.sign(), p.ckt()))
                    return false;
            }
        }
    }
    return true;
}

// Disabled legacy literal-based version of implyI, superseded by
// implyI(Fanout, bool) below.  It passes a Ref where Graph::getNodeType /
// getAndLeft expect a node id and uses oagAi node types, so it appears not to
// compile against the internal Graph as written.
#if 0
bool
CircuitSATxyz::implyI(Lit l)
{
    DEBUG(printf("implyI "); l.print(); printf("\n"));
    Ref               ref  = lit2ref(l);
    oagAi::Node::Type type = _graph.getNodeType(ref);

    if (type == oagAi::Node::AND) {
        Lit lt = ref2lit(_graph.getAndLeft (ref), l.ckt());
        Lit rt = ref2lit(_graph.getAndRight(ref), l.ckt());

        if (!l.sign()) {
            // 1 input
            // (~l + ~r + o)
            if (getValue(lt)==l_True and getValue(rt)==l_True)
                return enqueue(l);
            if (getValue(ref2lit(ref, l.ckt())) == l_False) {
                if (getValue(lt)==l_True) return enqueue(!rt);
                if (getValue(rt)==l_True) return enqueue(!lt);
            }
        } else {
            // 0 input
            // (l + ~o)
            // (r + ~o)
            return enqueue(l);
        }
    } else {
        // D value reached output?
        if (isCO(l.id())) {
            DEBUG(printf("PO: "); printValue(l.id()); printf("\n"));
            if (getValue(Lit(l.id(), !l.ckt(), l.sign())) == l_False)  {
                DEBUG(printf("reached PO %d\n", l.id()));
                _reachedPO = true;
            }
        }
        // buffer and/or output
        return enqueue(l);
    }
    return true;
}
#endif
/// Forward implication across fanout edge `fo` in circuit copy `ckt`.
/// The caller has already XOR-ed the source literal's sign into the edge, so
/// fo.sign() == false means the value entering node fo.id() is 1, and true
/// means it is 0.
///  - AND, input 0: output becomes 0.
///  - AND, input 1: output becomes 1 if the other input is 1; if the output
///    is already 0, the other input becomes 0.
///  - otherwise (buffer/output): the value is copied.  For a CO whose other
///    circuit copy holds the opposite value, _reachedPO is set.
/// Throws TimeOut once the propagation budget is exhausted; returns false on
/// conflict.
bool
CircuitSATxyz::implyI(Fanout fo, bool ckt)
{
    if (stats.numPropagations++ > params.maxNumPropagations)
        throw TimeOut();

    DEBUG(printf("implyI "); xyz2oag(Lit(fo.id(), ckt, fo.sign())).print(); printf("\n"));
    const unsigned node = fo.id();

    if (_graph.getNodeType(node) != Node::AND) {
        // buffer and/or output
        if (isCO(node)) {
            DEBUG(printf("PO: "); printValue(node); printf("\n"));
            // D value reached output?
            if (getValue(Lit(node, !ckt, fo.sign())) == l_False) {
                DEBUG(printf("reached PO %d\n", _int2ext[node]));
                _reachedPO = true;
            }
        }
        return enqueue(Lit(node, ckt, fo.sign()));
    }

    const Lit out(node, ckt);
    if (fo.sign())
        return enqueue(!out);           // 0 input: (l + ~o), (r + ~o)

    // 1 input: (~l + ~r + o)
    const Lit other = ref2lit(_graph.getAndInput(node, !fo.isRight()), ckt);
    if (getValue(other) == l_True)
        return enqueue(out);
    if (getValue(out) == l_False)
        return enqueue(!other);
    return true;
}

/// Backward implication from an assigned node value (literal `l`) to the
/// node's inputs, in the same circuit copy.
///  - AND at 1: both inputs must be 1.
///  - AND at 0: if one input is 1 the other must be 0; if either is already
///    0 nothing is needed; otherwise the output is recorded on the
///    justification frontier (_jFront) for decide().
///  - Terminal: a driven terminal passes the value (with polarity applied) to
///    its driver; a primary input imposes nothing.
///  - Constant zero: succeeds iff the constant-one literal is currently true
///    in circuit copy 0.
/// Returns false on conflict.
bool
CircuitSATxyz::implyO(Lit l)
{
    DEBUG(printf("implyO "); xyz2oag(l).print(); printf("\n"));
    Node::Type type = _graph.getNodeType(l.id());

    if (type == Node::AND) {
        Lit lt = ref2lit(_graph.getAndLeft (l.id()), l.ckt());
        Lit rt = ref2lit(_graph.getAndRight(l.id()), l.ckt());

        if (!l.sign()) {
            // 1 output
            // (l + ~o)
            // (r + ~o)
            return enqueue(lt) and enqueue(rt);
        } else {
            // 0 output
            // (~l + ~r + o)
            lbool lv = getValue(lt);
            if (lv == l_False) {
                return true;
            } else if (lv == l_True) {
                return enqueue(!rt);
            }

            lbool rv = getValue(rt);
            if (rv == l_False) {
                return true;
            } else if (rv == l_True) {
                return enqueue(!lt);
            }

            // both inputs unknown: needs justification later
            assert(lv==l_Undef and rv==l_Undef);
            _jFront.push_back(l);
            return true;
        }
    } else if (type == Node::TERMINAL) {
        Ref inp = _graph.getTerminalDriver(l.id());

        if (inp == _graph.getNull()) {
            // PI
            return true;
        } else {
            // buffer
            return enqueue(ref2lit(inp, l.ckt()) ^ l.sign());
        }
    } else {
        assert(type == Node::CONSTANT0);
        return (getValue(ref2lit(_graph.constantOne(), 0)) == l_True);
    }
}

/// Pick the next branching literal.
///
/// Before the fault effect has been seen at a circuit output (!_reachedPO):
/// walk the fault cone from the fault site through the fanout lists.  Return
/// the first node whose value is still unassigned in circuit copy 0 (checked
/// first) or copy 1.  Return lit_Error if no such node is reachable; solve()
/// then gives up on this branch.
///
/// Afterwards: return the left-input literal of the first _jFront entry that
/// is not yet justified (neither input already 0).  Return lit_Undef when
/// nothing is left to justify; solve() treats that as success.
Lit
CircuitSATxyz::decide()
{
    if (!_reachedPO) {
        // still need to reach PO
        DEBUG(printf("trying to reach PO\n"));
        vector<Ref> worklist(1, oag2xyz(_faultSite));
        vector<unsigned> tounmark;
        Lit decision = lit_Error;
        while (!worklist.empty()) {
            Ref r = worklist.back(); worklist.pop_back();
            // the visited flag is borrowed for this walk and cleared below
            if (isVisited(r.id())) continue;
            setVisited(r.id(), true);
            tounmark.push_back(r.id());
            //if (!visited.insert(r.id()).second) continue;

            assert(isInDcone(ref2id(r)));
            DEBUG(printf("checking to reach PO through "); printValue(ref2id(r)); printf("\n"));
            // can do something here?
            if (getValue(ref2lit(r, 0)) == l_Undef) { decision = ref2lit(r, 0); break; }
            if (getValue(ref2lit(r, 1)) == l_Undef) { decision = ref2lit(r, 1); break; }
            foreach (it, _graph.getFanout(r.id()))
                worklist.push_back(Ref((*it).id()));
        }
        foreach (it, tounmark)
            setVisited(*it, false);
        // can't reach PO
        // (with DEBUG compiled out this `if` governs an empty statement)
        if (decision==lit_Error)
            DEBUG(printf("couldn't reach PO!\n"));
        return decision;
        //return lit_Error;
    } else {
        DEBUG(printf("trying to justify\n"));
        // try to justify something
        foreach (it, _jFront) {
            Lit l   = *it;
            Lit lt  = ref2lit(_graph.getAndLeft (l.id()), l.ckt());
            Lit rt  = ref2lit(_graph.getAndRight(l.id()), l.ckt());
            // already justified by a 0 input
            if (getValue(lt) == l_False) continue;
            if (getValue(rt) == l_False) continue;
            // the frontier entry must be an AND output at 0 with no input at 1
            assertB (getValue(l.abs())==l_False
                    and getValue(lt)!=l_True
                    and getValue(rt)!=l_True) {
                printValue(l.id());  printf(" (");
                printValue(lt.id()); printf(" ");
                printValue(rt.id()); printf(")\n");
            }
            return lt;
        }
        // nothing to justify
        DEBUG(printf("nothing to justify!\n"));
        return lit_Undef;
    }
}

/// Recursive backtracking search.
///
/// Assigns `p`, propagates, and asks decide() for the next literal v:
///  - lit_Undef: nothing left to justify, so return true; the assignments are
///    kept for the caller to undo;
///  - lit_Error: no unassigned node reachable from the fault site, so fail;
///  - otherwise: try v first, then !v.
/// On failure, everything done at this level is undone before returning false:
/// _reachedPO, the _jFront entries, and the trail values.  _qhead is reset to
/// the trail size on entry.
/// May throw TimeOut from the implication routines.
bool
CircuitSATxyz::solve(Lit p)
{
    unsigned tlim = _trail.size();
    unsigned jlim = _jFront.size();
    bool     rlim = _reachedPO;

    if (enqueue(p) and propagate()) {
        Lit v = decide();

        DEBUG(printf("decided "); xyz2oag(v).print(); printf("\n"));

        if (v == lit_Undef) return true;
        if (v != lit_Error) {
            DEBUG(printf("branch %d 0\n", xyz2oag(v).id()));
            if (solve( v)) return true;
            DEBUG(printf("branch %d 1\n", xyz2oag(v).id()));
            if (solve(!v)) return true;
        }
    }

    // backtrack: restore the state saved on entry
    _reachedPO = rlim;
    while (_jFront.size() != jlim) {
        _jFront.pop_back();
    }
    while (_trail.size() != tlim) {
        setValue(_trail.back().id(), _trail.back().ckt(), l_Undef);
        _trail.pop_back();
    }
    _qhead = tlim;

    return false;
}


/// Assumption-based solving is not implemented.  No assumptions gives SAT;
/// any non-empty assumption set resets the queue state and reports TIMEOUT.
CircuitSAT::Result 
CircuitSATxyz::solve(const vector<oagAi::Ref> &assumps)
{
    assert(_oagGraph.getDirtyMarker() == _dirtyMarker);
    if (!assumps.empty()) {
        _reachedPO = true;
        _qhead     = _trail.size();
        return TIMEOUT;
    }
    return SAT;
}

/// One polarity of the substitution check run by isSubstitutable().
///
/// The fault site `tgt` is already disconnected.  Assume old = 0 in circuit
/// copy 0, tgt = 0 in copy 0 and tgt = 1 in copy 1.  Then search, via
/// solve(), starting from src = 1 in copy 0.
/// Returns:
///  - SAT if the search succeeds;
///  - UNSAT if it fails, or if the assumptions already conflict with the
///    current assignment;
///  - TIMEOUT if the propagation budget runs out.
/// All assignments made here are undone, and _jFront is left empty.
CircuitSAT::Result
CircuitSATxyz::isSubstitutableH(oagAi::Ref src, oagAi::Ref old, oagAi::Ref tgt)
{
    DEBUG(printf("CHECK HALF!\n"));
    assert(_jFront.empty());
    unsigned tlim = _trail.size();
    _qhead     = tlim;
    _reachedPO = false;

    // enqueue() returns false when a literal is already false under the
    // current assignment, so the assumptions cannot hold together.  Searching
    // on without the failed constraint (as before) could only produce a "SAT"
    // answer that violates it, so report UNSAT instead.
    const bool assumable = enqueue(!ref2lit(oag2xyz(old), 0))
                        && enqueue(!ref2lit(oag2xyz(tgt), 0))
                        && enqueue( ref2lit(oag2xyz(tgt), 1));

    Result res = UNSAT;
    if (assumable) {
        try {
            res = solve( ref2lit(oag2xyz(src), 0)) ? SAT : UNSAT;
        } catch (TimeOut) {
            res = TIMEOUT;
        }
    }

    // undo everything assigned since tlim
    _jFront.clear();
    while (_trail.size() != tlim) {
        setValue(_trail.back().id(), _trail.back().ckt(), l_Undef);
        _trail.pop_back();
    }

    return res;
}

/// Check whether terminal `tgt` may be re-driven by `src`.
///
/// The terminal is temporarily disconnected, which makes it a free fault site.
/// Its fault cone is marked and constant one is asserted and propagated.  Two
/// half-checks then run via isSubstitutableH(): the given polarities first, and
/// the complemented ones only if the first returns UNSAT.  The result of the
/// last half-check run is returned.  It is UNSAT if the constant propagation
/// itself conflicts, and TIMEOUT if the budget runs out.  Non-terminal targets
/// are not implemented and report TIMEOUT.
/// All temporary state (trail values, fault-cone flags, the original driver)
/// is restored before returning.
CircuitSAT::Result 
CircuitSATxyz::isSubstitutable(oagAi::Ref src, oagAi::Ref tgt)
{
    DEBUG(printf("checking to sub %d for %d\n", src, tgt));
    assert(_oagGraph.getDirtyMarker() == _dirtyMarker);

    Result res;
    if (_oagGraph.getNodeType(tgt) == oagAi::Node::TERMINAL) {
        // remember the original driver and disconnect the fault site
        oagAi::Ref ref0 = _oagGraph.getTerminalDriver(tgt);
        setTerminalDriver(tgt, _oagGraph.getNull());

        stats.numPropagations = 0;
        _faultSite = tgt;
        try {
            markDcone(oag2xyz(tgt).id(), true);

            assert(_trail.empty());
            _qhead = 0;
            check(enqueue(ref2lit(_graph.constantOne(), 0)));
            //check(propagate());
            if (!propagate()) {
                res = UNSAT;
            } else {

            res = isSubstitutableH(src, ref0, tgt);
            if (res == UNSAT) {
                res = isSubstitutableH(_oagGraph.notOf(src), _oagGraph.notOf(ref0), _oagGraph.notOf(tgt));
            }
            }
        } catch (TimeOut) {
            res = TIMEOUT;
        }
        // undo the constant propagation
        while (!_trail.empty()) {
            setValue(_trail.back().id(), _trail.back().ckt(), l_Undef);
            _trail.pop_back();
        }

        markDcone(oag2xyz(tgt).id(), false);
        setTerminalDriver(tgt, ref0);
        DEBUG(printf("%d propagations\n", stats.numPropagations));
    } else {
        // TODO: implement me
        res = TIMEOUT;
    }
    // the temporary rewiring went through _oagGraph.setTerminalDriver, which
    // presumably bumped its dirty marker; resync so later asserts hold
    _dirtyMarker = _oagGraph.getDirtyMarker();
    DEBUG(printf("subst result %d\n", res));
    return res;
}

/// Permanently re-drive terminal `tgt` by `src` in both graphs.  Non-terminal
/// targets are not supported and trip an assertion.
void   
CircuitSATxyz::substitute(oagAi::Ref src, oagAi::Ref tgt)
{
    assert(_oagGraph.getDirtyMarker() == _dirtyMarker);
    if (_oagGraph.getNodeType(tgt) != oagAi::Node::TERMINAL) {
        assert("not implemented yet!\n" && 0);
    } else {
        setTerminalDriver(tgt, src);
    }
    // our own rewiring changed _oagGraph; resync the marker
    _dirtyMarker = _oagGraph.getDirtyMarker();
}

/// Factory for the XYZ solver.  The object is heap-allocated and this
/// function keeps no pointer to it, so presumably the caller owns it.
CircuitSAT* 
CircuitSAT::createXYZ(oagAi::Graph &g, const std::vector<oagAi::Ref> &o, bool doms)
{
    CircuitSAT *solver = new CircuitSATxyz(g, o, doms);
    return solver;
}

}
// vim:set et:
