#include "oagRedunCircuitSAT.h"

// Debug tracing hook: expands to nothing, so every DEBUG(...) statement in
// this file is compiled out.
#define DEBUG(x) 
#ifndef NDEBUG
// assertB(cond) { block }: when cond is false, run the attached block
// (used in this file to print diagnostics) and then evaluate assert(cond).
#define assertB(x) for (;!(x); assert(x)) 
// {}
#else
// With NDEBUG the attached block is never executed and cond is not evaluated.
#define assertB(x) if (0) 
// {}
#endif
// Iterate i over [x.begin(), x.end()).  Relies on the `typeof` compiler
// extension.  x is re-evaluated for every loop test, so it must be cheap and
// must yield the same range each time.
#define foreach(i,x) for(typeof((x).begin()) i=(x).begin(); i!=(x).end(); ++i)
using namespace std;

// check(expr): always evaluates expr (even with NDEBUG) and asserts that the
// result is true.
#define check(x) { bool qQqQq = x; assert(qQqQq); }

namespace {

//void check(bool b) { assert(b); }

/// A literal: node id, circuit selector and sign packed into one word.
///
/// Layout of _data, as built by the constructor:
///   bits 2..  node id
///   bit  1    circuit selector (ckt)
///   bit  0    sign
class Lit {
public:
    unsigned _data;
public:
    /// Packs id @p i, circuit bit @p c and sign @p s (default: not signed).
    Lit(unsigned i, bool c, bool s = false)
        : _data((i << 2) | (static_cast<unsigned>(c) << 1) | static_cast<unsigned>(s)) { }

    /// Node id (everything above the two low bits).
    unsigned id ()  const { return _data >> 2; }
    /// Circuit selector bit.
    bool     ckt()  const { return (_data & 2u) != 0; }
    /// Id and circuit bit together, i.e. the literal without its sign.
    unsigned var()  const { return _data >> 1; }
    /// Sign bit.
    bool     sign() const { return (_data & 1u) != 0; }

    bool operator==(const Lit l) const { return _data == l._data; }
    bool operator!=(const Lit l) const { return !(*this == l); }

    /// Same id and circuit, sign cleared.
    Lit  abs() const { return Lit(id(), ckt(), false); }
    /// Flips the sign when @p b is true, otherwise returns an equal literal.
    Lit  operator^(bool b) const { return Lit(id(), ckt(), sign() != b); }
    /// Same id and circuit, sign flipped.
    Lit  operator!() const { return Lit(id(), ckt(), !sign()); }

    /// Writes "<id>:<ckt>:<v>" to @p fp, where <v> is 0 for a signed literal
    /// and 1 otherwise (wrapped in colour escapes when COLOR is defined).
    void print(FILE *fp=stdout) const {
#ifdef COLOR
        const char *fmt = sign() ? "%d:%d:\033[31m0\033[0m"
                                 : "%d:%d:\033[32m1\033[0m";
#else
        const char *fmt = sign() ? "%d:%d:0"
                                 : "%d:%d:1";
#endif
        fprintf(fp, fmt, id(), ckt());
    }

};
/// Sentinel literals; decide() returns these to signal "nothing left to
/// decide" (lit_Undef) and "cannot proceed" (lit_Error).
const Lit lit_Undef(UINT_MAX, false);
const Lit lit_Error(UINT_MAX, true);


/// Three-valued logic value, stored as 0 (false), 1 (true) or 3 (unknown,
/// printed as X).
class lbool {
    char value;
    /// Raw-encoding constructor; only reachable through toLbool().
    explicit lbool(char v) : value(v) { }
public:
    /// Default is the unknown value.
    lbool()       : value(3) { }
    lbool(bool x) : value(x) { }
    int toInt(void) const { return value; }
    operator int() const { return value; }
    bool  operator==(lbool b) const { return value == b.value; }
    bool  operator!=(lbool b) const { return !(*this == b); }
    /// Inverts a 0/1 value when @p b is true.  Bit 1 of the encoding is
    /// folded back into bit 0, so the unknown value (3) stays unknown.
    lbool operator^ (bool  b) const {
        return lbool(static_cast<char>((value ^ b) | (value >> 1)));
    }
    /// Writes "0", "1" or "X" to @p fp (wrapped in colour escapes when COLOR
    /// is defined).  Any encoding other than 0 or 1 prints as X.
    void print(FILE *fp=stdout) const {
#ifdef COLOR
        static const char *const text[3] = {
            "\033[31m0\033[0m", "\033[32m1\033[0m", "\033[33mX\033[0m"
        };
#else
        static const char *const text[3] = { "0", "1", "X" };
#endif
        const int which = (value == 0) ? 0 : ((value == 1) ? 1 : 2);
        fputs(text[which], fp);
    }

    friend char  toInt  (lbool l);
    friend lbool toLbool(char v); 
};
/// Raw encoding of @p l as a char.
inline char  toInt  (lbool l) { return static_cast<char>(l.toInt()); }
/// Builds an lbool directly from its raw encoding.
inline lbool toLbool(char v)  { return lbool(v); }

const lbool l_True  = toLbool(1);
const lbool l_False = toLbool(0);
const lbool l_Undef = toLbool(3);

/// Range over the fanout of one graph reference, usable with foreach().
///
/// Dereferencing the iterator does not return the raw fanout entry: it
/// returns the fanout node passed through notCondOf() together with the
/// inversion flag of the input edge that connects it to the source reference.
class Fanouts {
    const oagAi::Graph &_g;
    const oagAi::Ref    _ref;

    class iterator {
        const oagAi::Graph               &_g;
        const oagAi::Ref                  _ref;
        list<oagAi::Ref>::const_iterator  _i;
    public:
        iterator(const oagAi::Graph &g, oagAi::Ref r, list<oagAi::Ref>::const_iterator i)
            : _g(g), _ref(r), _i(i) { }
        void operator++() { ++_i; }
        oagAi::Ref operator*() {
            const oagAi::Ref fo   = *_i;
            const oagAi::Ref self = _g.getNonInverted(_ref);

            if (_g.getNodeType(fo) != oagAi::Node::AND) {
                // Anything that is not an AND must be a terminal driven by us.
                assert(_g.getNodeType(fo) == oagAi::Node::TERMINAL);
                const oagAi::Ref drv = _g.getTerminalDriver(fo);
                assert(_g.getNonInverted(drv) == self);
                return _g.notCondOf(fo, _g.isInverted(drv));
            }

            // AND node: find which input is the source; the left input is
            // checked first.
            const oagAi::Ref lhs = _g.getAndLeft(fo);
            if (_g.getNonInverted(lhs) == self)
                return _g.notCondOf(fo, _g.isInverted(lhs));

            const oagAi::Ref rhs = _g.getAndRight(fo);
            if (_g.getNonInverted(rhs) == self)
                return _g.notCondOf(fo, _g.isInverted(rhs));

            // A fanout entry that does not reference the source on either input.
            assert(0);
            abort();
        }
        bool operator!=(iterator j) const { return !(_i == j._i); }
        bool operator==(iterator j) const { return _i == j._i; }
    };
public:
    Fanouts(const oagAi::Graph &g, oagAi::Ref r) : _g(g), _ref(r) { }
    /// Iterator at the first entry of the graph's fanout list for the reference.
    iterator begin() { return iterator(_g, _ref, _g.getFanout(_ref).begin()); }
    /// Past-the-end iterator of the same fanout list.
    iterator end()   { return iterator(_g, _ref, _g.getFanout(_ref).end());   }
};

}

namespace oagRedun {

/// The solver: a CircuitSAT implementation that works directly on an
/// oagAi::Graph.
///
/// Per-node state is stored in the graph's user-data word 0 (see
/// getData/setData).  As used by the member functions in this file, that word
/// is laid out as:
///   bits 0-1  value for circuit 0
///   bits 2-3  value for circuit 1
///   bit  4    node is in the fault cone ("Dcone")
///   bit  5    node is a CO (one of the refs passed to the constructor)
class CircuitSAToag : public CircuitSAT {
    typedef oagAi::Graph    Graph;
    typedef oagAi::Ref      Ref;

    // Conversions between node ids and graph references.
    static Ref id2ref(unsigned id) { return Graph::getRefFromID(id); }
    static unsigned ref2id(Ref rf) { return Graph::getNodeID(rf);    }

    // Conversions between literals and graph references.  The literal's sign
    // maps to the reference's inversion flag; the circuit bit c has no
    // counterpart in a Ref and must be supplied by the caller.
    static Ref lit2ref(Lit l)         { return Graph::getRefFromID(l.id(), l.sign()); }
    static Lit ref2lit(Ref r, bool c) { return Lit(Graph::getNodeID(r), c, Graph::isInverted(r)); }

    /// AI graph
    Graph             &_graph;
    /// Dirty marker of _graph as last recorded by this object; the public
    /// entry points assert that it still matches before doing any work.
    unsigned           _dirtyMarker;
    // Raw access to the per-node user-data word described above.
    unsigned getData(unsigned id) const { return _graph.getUserData(id2ref(id), 0); }
    void setData(unsigned id, unsigned x) { _graph.setUserData(id2ref(id), 0, x); }

    // CO flag (bit 5).
    void setCO(unsigned id, bool);
    bool isCO (unsigned id);

    // Resets the flags and value of every node reachable from the given ref
    // that has not been visited in the current traversal.
    void clearDataRec(Ref);

    /// fault cone
    /// Ids of all nodes currently flagged as being in the fault cone.
    vector<unsigned>        _faultCone;
    void setDcone (unsigned id, bool);
    bool isInDcone(unsigned id);
    // true: flag the transitive fanout of id; false: unflag everything that
    // is recorded in _faultCone.
    void markDcone(unsigned id, bool);

    /// SAT stuff
    lbool getValue(Lit);
    void  setValue(unsigned id, bool ckt, lbool);
    // Assigns a literal; returns false if it already has the opposite value.
    bool  enqueue (Lit);
    void  printValue(unsigned id);

    // implication queue and justification queue
    unsigned         _qhead;      // index of the next _trail entry to propagate
    vector<Lit>      _trail;      // assigned literals in assignment order
    vector<Lit>      _jFront;     // literals recorded by implyO whose inputs are both unknown
    bool             _reachedPO;  // set by implyI when it sees a difference at a CO
    // Thrown by implyI and markDcone when params.maxNumPropagations is exceeded.
    struct TimeOut : public exception { };

    // Implication towards a node from one of its inputs (implyI) and from a
    // node's assigned value towards its inputs (implyO).  Both return false
    // on conflict.
    bool implyI(Lit);
    bool implyO(Lit);

    Lit  decide();
    bool propagate();
    bool solve(Lit p);

    struct {
        unsigned numPropagations;   // number of implyI calls since last reset
    } stats;

    /// Terminal currently being checked by isSubstitutable; decide() starts
    /// its search from here.
    Ref _faultSite;
    Result isSubstitutableH(Ref src, Ref old, Ref tgt);
public:
    CircuitSAToag(Graph &g, const vector<Ref> &o);

    virtual Result solve(const std::vector<Ref> &assumps);

    virtual Result isSubstitutable(Ref src, Ref tgt);
    virtual void   substitute     (Ref src, Ref tgt);
};

/// Resets the solver state (CO flag, fault-cone flag, value) of every node
/// reachable from @p ref through fanin or fanout edges that has not yet been
/// visited in the graph's current traversal.
///
/// The walk uses an explicit worklist instead of recursion: because it
/// follows edges in both directions, a recursive walk can nest as deep as the
/// number of reachable nodes and overflow the call stack on large graphs.
/// The set of nodes cleared is the same; only the visiting order differs.
void
CircuitSAToag::clearDataRec(Ref ref)
{
    vector<Ref> pending(1, ref);

    while (!pending.empty()) {
        const Ref cur = pending.back();
        pending.pop_back();

        if (_graph.isVisited(cur)) continue;
        _graph.markVisited(cur);

        const unsigned id = ref2id(cur);
        setCO   (id, false);
        // Clear the cone flag before the value: setValue() only resets both
        // circuits' values when the node is not in the fault cone.
        setDcone(id, false);
        setValue(id, 0, l_Undef);

        foreach (it, _graph.getFanout(cur))
            pending.push_back(*it);

        switch (_graph.getNodeType(cur)) {
            case oagAi::Node::AND:
                pending.push_back(_graph.getAndLeft (cur));
                pending.push_back(_graph.getAndRight(cur));
                break;
            case oagAi::Node::TERMINAL:
                pending.push_back(_graph.getTerminalDriver(cur));
                break;
            case oagAi::Node::NONE:
            case oagAi::Node::SEQUENTIAL:
            case oagAi::Node::CONSTANT0:
                // No fanin to follow.
                break;
        }
    }
}

/// Binds the solver to graph @p g and prepares the per-node user data.
///
/// Starts a new traversal, clears the solver state of everything reachable
/// from the refs in @p o and from the constant-one ref, and flags each ref in
/// @p o as a CO.  All scalar members are given defined initial values so that
/// no entry point can read an indeterminate value.
///
/// @param g  graph to operate on; held by reference.
/// @param o  refs to be treated as COs.
CircuitSAToag::CircuitSAToag(Graph &g, const vector<Ref> &o)
    : _graph(g),
      _dirtyMarker(g.getDirtyMarker()),
      _qhead(0),
      _reachedPO(false),
      _faultSite(g.getNull())
{
    stats.numPropagations = 0;

    g.newTraversalID();
    // Clear everything first; CO flags are set in a second pass so that a
    // later clear cannot wipe a flag that was already set.
    foreach (it, o) {
        clearDataRec(*it);
    }
    foreach (it, o) {
        setCO(ref2id(*it), true);
    }
    clearDataRec(g.constantOne());
    // The constant's value is not assigned here; isSubstitutable() enqueues
    // the constant-one literal at the start of every check.
}

/// Sets or clears the CO flag (bit 0x20 of the node's user data) of node
/// @p id, leaving every other bit untouched.
void 
CircuitSAToag::setCO(unsigned id, bool b)
{
    const unsigned coBit   = 0x20;
    const unsigned cleared = getData(id) & ~coBit;
    setData(id, b ? (cleared | coBit) : cleared);
}

/// Returns true when the CO flag (bit 0x20 of the user data) of node @p id
/// is set.
bool 
CircuitSAToag::isCO(unsigned id)
{
    const unsigned word = getData(id);
    return (word & 0x20) != 0;
}

/// Sets or clears the fault-cone flag (bit 0x10 of the node's user data) of
/// node @p id, leaving every other bit untouched.
void 
CircuitSAToag::setDcone(unsigned id, bool b)
{
    const unsigned coneBit = 0x10;
    const unsigned cleared = getData(id) & ~coneBit;
    setData(id, b ? (cleared | coneBit) : cleared);
}

/// Returns true when the fault-cone flag (bit 0x10 of the user data) of node
/// @p id is set.
bool 
CircuitSAToag::isInDcone(unsigned id)
{
    const unsigned word = getData(id);
    return (word & 0x10) != 0;
}

void 
CircuitSAToag::markDcone(unsigned id, bool b)
{
    if (b) {
        assert(_faultCone.empty());
        unsigned fhead = 0;
        setDcone(id, true);
        _faultCone.push_back(id);
        while (fhead < _faultCone.size()) {
            unsigned id = _faultCone[fhead++];
            foreach (it, _graph.getFanout(id2ref(id))) {
                unsigned foid = ref2id(*it);
                // already marked?
                if (isInDcone(foid)) continue;
                setDcone(foid, true);
                _faultCone.push_back(foid);
            }
            if (fhead > params.maxNumPropagations) throw TimeOut();
        }
    } else {
        foreach (it, _faultCone) {
            setDcone     (*it, false);
        }
        _faultCone.clear();
    }
}

/// Returns the current value of literal @p l.
///
/// Reads the two-bit value field of the literal's circuit from the node's
/// user data (bits 0-1 for circuit 0, bits 2-3 for circuit 1) and inverts it
/// when the literal is signed.
lbool
CircuitSAToag::getValue(Lit l)
{
    const unsigned word  = getData(l.id());
    const unsigned shift = l.ckt() ? 2 : 0;
    const char     bits  = (word >> shift) & 0x3;
    return toLbool(bits) ^ l.sign();
}

void
CircuitSAToag::printValue(unsigned id) 
{
    lbool v0 = getValue(Lit(id, 0));
    lbool v1 = getValue(Lit(id, 1));
    printf("%d:", id);
    v0.print();
    printf("/");
    v1.print();
}

/// Stores @p val as the value of node @p id.
///
/// For a node inside the fault cone only the two-bit field of circuit @p ckt
/// is rewritten.  For a node outside the fault cone both circuits' fields
/// are overwritten with @p val and @p ckt is ignored.
void
CircuitSAToag::setValue(unsigned id, bool ckt, lbool val)
{
    //printf("setval %d %d %d\n", id, ckt, int(val));
    const unsigned enc  = toInt(val);
    unsigned       word = getData(id);

    if (isInDcone(id)) {
        const unsigned shift = ckt ? 2 : 0;
        word = (word & ~(0x3u << shift)) | (enc << shift);
    } else {
        word = (word & ~0xFu) | (enc << 2) | enc;
    }
    setData(id, word);
}

/// Tries to make literal @p l true.
///
/// @return true if @p l is already true or was unassigned (in which case it
///         is assigned and appended to the trail); false if @p l is already
///         false, i.e. the assignment conflicts.
bool
CircuitSAToag::enqueue(Lit l)
{
    const lbool cur = getValue(l);
    DEBUG(printf("enqueue "); if (isInDcone(l.id())) printf("DC"); l.print(); printf(" (%d)\n", int(cur)));

    // Already decided one way or the other: nothing to record.
    if (cur == l_True or cur == l_False)
        return cur == l_True;

    // The stored value is that of the unsigned literal, hence the negation.
    setValue(l.id(), l.ckt(), !l.sign());
    _trail.push_back(l);
    return true;
}

bool
CircuitSAToag::propagate()
{
    while (_qhead < _trail.size()) {
        Lit p = _trail[_qhead++];
        
        if (!implyO(p)) return false;

        if (!isInDcone(p.id())) {
            foreach (it, Fanouts(_graph, lit2ref(p))) {
                Ref ref = *it;
                DEBUG(printf("%s%d -> %d\n", _graph.isInverted(ref)?"~":"", p.id(), ref2id(ref)));
                if (!implyI(ref2lit(ref, p.ckt()) ^ p.sign()))
                    return false;
                if (isInDcone(ref2id(ref))) {
                    if (!implyI(ref2lit(ref, !p.ckt()) ^ p.sign()))
                        return false;
                }
            }
        } else {
            foreach (it, Fanouts(_graph, lit2ref(p))) {
                Ref ref = *it;
                DEBUG(printf("%s%d -> %d\n", _graph.isInverted(ref)?"~":"", p.id(), ref2id(ref)));
                if (!implyI(ref2lit(ref, p.ckt()) ^ p.sign()))
                    return false;
            }
        }
    }
    return true;
}

/// Implication towards a node from one of its inputs.
///
/// @p l names the node (id, circuit) and, through its sign, the value that
/// arrived on the input: unsigned for a 1 input, signed for a 0 input.
/// Every call counts against params.maxNumPropagations.
///
/// @return false on conflict, true otherwise.
/// @throws TimeOut when the propagation limit is exceeded.
bool
CircuitSAToag::implyI(Lit l)
{
    if (stats.numPropagations++ > params.maxNumPropagations)
        throw TimeOut();

    DEBUG(printf("implyI "); l.print(); printf("\n"));
    const Ref node = lit2ref(l);

    if (_graph.getNodeType(node) != oagAi::Node::AND) {
        // D value reached output?
        if (isCO(l.id())) {
            DEBUG(printf("PO: "); printValue(l.id()); printf("\n"));
            // The same literal in the other circuit is false, so the two
            // circuits differ at this CO.
            if (getValue(Lit(l.id(), !l.ckt(), l.sign())) == l_False)  {
                DEBUG(printf("reached PO %d\n", l.id()));
                _reachedPO = true;
            }
        }
        // buffer and/or output: the value passes straight through
        return enqueue(l);
    }

    if (l.sign()) {
        // 0 input forces a 0 output:  (l + ~o), (r + ~o)
        return enqueue(l);
    }

    // 1 input:  (~l + ~r + o)
    const Lit lt = ref2lit(_graph.getAndLeft (node), l.ckt());
    const Lit rt = ref2lit(_graph.getAndRight(node), l.ckt());
    const bool leftTrue  = getValue(lt) == l_True;
    const bool rightTrue = getValue(rt) == l_True;

    if (leftTrue and rightTrue)
        return enqueue(l);

    // NOTE: the original author marked this test "XXX: BUG?!" and kept
    // getValue(ref2lit(ref, l.ckt())) as a commented-out alternative.  In
    // this branch l.sign() is false, so Lit(l.id(), l.ckt()) equals l.
    if (getValue(Lit(l.id(), l.ckt())) == l_False) {
        // Output is 0 and one input is 1: the other input must be 0.
        if (leftTrue)  return enqueue(!rt);
        if (rightTrue) return enqueue(!lt);
    }
    return true;
}

/// Implication from a node's assigned value towards its inputs.
///
/// - TERMINAL: a terminal whose driver is null needs nothing; otherwise the
///   value is pushed onto the driver.
/// - AND, output 1: both inputs must be 1.
/// - AND, output 0: if exactly one input is known to be 1 the other is
///   forced to 0; if both inputs are unknown, @p l is appended to _jFront.
/// - anything else is asserted to be CONSTANT0 and succeeds only while the
///   constant-one literal has value true.
///
/// @return false on conflict, true otherwise.
bool
CircuitSAToag::implyO(Lit l)
{
    DEBUG(printf("implyO "); l.print(); printf("\n"));
    const Ref               node = lit2ref(l);
    const oagAi::Node::Type kind = _graph.getNodeType(node);

    if (kind == oagAi::Node::TERMINAL) {
        const Ref driver = _graph.getTerminalDriver(node);
        if (_graph.isNull(driver)) {
            // PI
            return true;
        }
        // buffer
        return enqueue(ref2lit(driver, l.ckt()) ^ l.sign());
    }

    if (kind != oagAi::Node::AND) {
        assert(kind == oagAi::Node::CONSTANT0);
        return (getValue(ref2lit(_graph.constantOne(), 0)) == l_True);
    }

    const Lit lt = ref2lit(_graph.getAndLeft (node), l.ckt());
    const Lit rt = ref2lit(_graph.getAndRight(node), l.ckt());

    if (!l.sign()) {
        // 1 output:  (l + ~o), (r + ~o)
        return enqueue(lt) and enqueue(rt);
    }

    // 0 output:  (~l + ~r + o)
    const lbool lv = getValue(lt);
    if (lv == l_False) return true;             // already justified
    if (lv == l_True)  return enqueue(!rt);

    const lbool rv = getValue(rt);
    if (rv == l_False) return true;             // already justified
    if (rv == l_True)  return enqueue(!lt);

    // Both inputs unknown: remember the node so decide() can justify it.
    assert(lv==l_Undef and rv==l_Undef);
    _jFront.push_back(l);
    return true;
}

/// Picks the next literal to branch on.
///
/// While no CO has been reached, the fanout of _faultSite is searched for
/// the first node with an unassigned value (circuit 0 is checked before
/// circuit 1).  Once a CO has been reached, the entries of _jFront are
/// scanned and the left input of the first entry with no false input is
/// returned.
///
/// @return the literal to branch on; lit_Error when no CO has been reached
///         and nothing in the searched fanout is unassigned; lit_Undef when
///         a CO has been reached and nothing is left to justify.
Lit
CircuitSAToag::decide()
{
    if (_reachedPO) {
        DEBUG(printf("trying to justify\n"));
        // try to justify something
        foreach (jt, _jFront) {
            const Lit out  = *jt;
            const Ref node = lit2ref(out);
            const Lit lt   = ref2lit(_graph.getAndLeft (node), out.ckt());
            const Lit rt   = ref2lit(_graph.getAndRight(node), out.ckt());

            // An input that is already 0 justifies the 0 output.
            if (getValue(lt) == l_False or getValue(rt) == l_False) continue;

            assertB (getValue(out.abs())==l_False
                    and getValue(lt)!=l_True
                    and getValue(rt)!=l_True) {
                printValue(out.id()); printf(" (");
                printValue(lt.id());  printf(" ");
                printValue(rt.id());  printf(")\n");
            }
            return lt;
        }
        // nothing to justify
        DEBUG(printf("nothing to justify!\n"));
        return lit_Undef;
    }

    // still need to reach PO
    DEBUG(printf("trying to reach PO\n"));
    _graph.newTraversalID();
    vector<Ref> todo(1, _faultSite);
    while (!todo.empty()) {
        const Ref r = todo.back();
        todo.pop_back();

        if (_graph.isVisited(r)) continue;
        _graph.markVisited(r);
        assert(isInDcone(ref2id(r)));
        DEBUG(printf("checking to reach PO through "); printValue(ref2id(r)); printf("\n"));

        // can do something here?
        for (int c = 0; c < 2; ++c) {
            const Lit cand = ref2lit(r, c != 0);
            if (getValue(cand) == l_Undef) return cand;
        }

        foreach (fo, _graph.getFanout(r))
            todo.push_back(*fo);
    }
    // can't reach PO
    DEBUG(printf("couldn't reach PO!\n"));
    return lit_Error;
}

/// Recursive branch-and-backtrack search.
///
/// Assigns @p p and propagates.  If that succeeds, decide() picks a branching
/// literal and both of its polarities are tried recursively.
///
/// @return true when decide() reports that nothing is left to do; the
///         assignments made along the way are then left in place for the
///         caller to undo.  false when this branch fails; _reachedPO,
///         _jFront, _trail (with the node values) and _qhead are then
///         restored to their state on entry.
bool
CircuitSAToag::solve(Lit p)
{
    // Snapshot of everything that has to be rolled back on failure.
    const unsigned trailMark   = _trail.size();
    const unsigned frontMark   = _jFront.size();
    const bool     reachedMark = _reachedPO;

    const bool consistent = enqueue(p) and propagate();
    if (consistent) {
        const Lit v = decide();

        DEBUG(printf("decided "); v.print(); printf("\n"));

        if (v == lit_Undef) return true;
        if (v != lit_Error) {
            DEBUG(printf("branch %d 0\n", v.id()));
            if (solve( v)) return true;
            DEBUG(printf("branch %d 1\n", v.id()));
            if (solve(!v)) return true;
        }
    }

    // Backtrack.
    _reachedPO = reachedMark;
    _jFront.erase(_jFront.begin() + frontMark, _jFront.end());
    for (; _trail.size() != trailMark; _trail.pop_back()) {
        const Lit undone = _trail.back();
        setValue(undone.id(), undone.ckt(), l_Undef);
    }
    _qhead = trailMark;

    return false;
}

/// Solve under a set of assumptions.
///
/// Only the trivial case is handled: an empty assumption list yields SAT.
/// For any non-empty list the queue state is reset and TIMEOUT is returned
/// without using the assumptions.
CircuitSAT::Result 
CircuitSAToag::solve(const vector<oagAi::Ref> &assumps)
{
    assert(_graph.getDirtyMarker() == _dirtyMarker);

    if (!assumps.empty()) {
        _reachedPO = true;
        _qhead     = _trail.size();
        return TIMEOUT;
    }
    return SAT;
}

/// Runs one polarity of the substitution check.
///
/// Asserts the negation of @p old and of @p tgt in circuit 0 and @p tgt in
/// circuit 1, then searches with @p src asserted in circuit 0.  Whatever the
/// outcome, _jFront is emptied and the trail is unwound to its length on
/// entry before returning.
///
/// @param src  candidate replacement ref.
/// @param old  ref that drove the target before the check began.
/// @param tgt  the target terminal ref.
/// @return SAT if the search succeeds, UNSAT if it fails, TIMEOUT if the
///         propagation limit was hit.
CircuitSAT::Result
CircuitSAToag::isSubstitutableH(oagAi::Ref src, oagAi::Ref old, oagAi::Ref tgt)
{
    DEBUG(printf("CHECK HALF!\n"));
    assert(_jFront.empty());
    const unsigned tlim = _trail.size();
    _qhead     = tlim;
    _reachedPO = false;

    // NOTE(review): the results of these three enqueue() calls are ignored,
    // so an assumption that conflicts with an existing assignment is silently
    // dropped instead of ending the check -- confirm this is intended.
    enqueue(!ref2lit(old, 0));
    enqueue(!ref2lit(tgt, 0));
    enqueue( ref2lit(tgt, 1));

    Result res;

    try {
        res = solve( ref2lit(src, 0)) ? SAT : UNSAT;
    } catch (const TimeOut &) {
        // Caught by const reference: no copy of the exception object.
        res = TIMEOUT;
    }

    // solve() leaves its assignments in place on success and may have been
    // interrupted by TimeOut, so always unwind here.
    _jFront.clear();
    while (_trail.size() != tlim) {
        setValue(_trail.back().id(), _trail.back().ckt(), l_Undef);
        _trail.pop_back();
    }

    return res;
}

/// Checks whether terminal @p tgt can be driven by @p src instead of its
/// current driver.
///
/// Only TERMINAL targets are supported; any other node type yields TIMEOUT.
/// For a terminal, the driver is temporarily disconnected, the fanout of
/// @p tgt is marked as the fault cone, the constant-one literal is propagated
/// and isSubstitutableH() is run for one polarity and, if that returns UNSAT,
/// for the inverted polarity.  Afterwards all assignments are undone, the
/// fault cone is unmarked and the original driver is restored.
///
/// @return the result of the last isSubstitutableH() call, or TIMEOUT when
///         the propagation limit was exceeded or @p tgt is not a terminal.
CircuitSAT::Result 
CircuitSAToag::isSubstitutable(oagAi::Ref src, oagAi::Ref tgt)
{
    DEBUG(printf("checking to sub %d for %d\n", src, tgt));
    assert(_graph.getDirtyMarker() == _dirtyMarker);

    Result res;
    if (_graph.getNodeType(tgt) == oagAi::Node::TERMINAL) {
#if 0
        // check if refs already have values
        lbool vsrc = getValue(ref2lit(src, 0));
        lbool vtgt = getValue(ref2lit(tgt, 0));
        if (vtgt != l_Undef) {
            return (vsrc == vtgt) ? UNSAT : SAT;
        }
        if (vsrc != l_Undef and vtgt != l_Undef) {
            return (vsrc == vtgt) ? UNSAT : SAT;
        }
#endif
        // Disconnect the terminal for the duration of the check; the driver
        // is restored below.
        Ref ref0 = _graph.getTerminalDriver(tgt);
        _graph.setTerminalDriver(tgt, _graph.getNull());

        stats.numPropagations = 0;
        _faultSite = tgt;
        try {
            markDcone(ref2id(tgt), true);

            assert(_trail.empty());
            _qhead = 0;
            check(enqueue(ref2lit(_graph.constantOne(), 0)));
            check(propagate());
            //printf("%d from constant\n", stats.numPropagations);

            res = isSubstitutableH(src, ref0, tgt);
            if (res == UNSAT) {
                // First polarity is fine; check the inverted one as well.
                res = isSubstitutableH(_graph.notOf(src), _graph.notOf(ref0), _graph.notOf(tgt));
            }
        } catch (const TimeOut &) {
            // Caught by const reference: no copy of the exception object.
            // markDcone() and the constant propagation can time out too.
            res = TIMEOUT;
        }

        // Undo every assignment made during the check.
        while (!_trail.empty()) {
            setValue(_trail.back().id(), _trail.back().ckt(), l_Undef);
            _trail.pop_back();
        }

        markDcone(ref2id(tgt), false);
        _graph.setTerminalDriver(tgt, ref0);
        DEBUG(printf("%d propagations\n", stats.numPropagations));
    } else {
        // TODO: implement me
        res = TIMEOUT;
    }
    // Re-record the dirty marker after the terminal driver was changed and
    // restored above.
    _dirtyMarker = _graph.getDirtyMarker();
    DEBUG(printf("subst result %d\n", res));
    return res;
}

/// Makes @p src the driver of terminal @p tgt and re-records the graph's
/// dirty marker.  Targets that are not terminals are not supported: an
/// assertion fires in debug builds, and nothing is changed otherwise.
void   
CircuitSAToag::substitute(oagAi::Ref src, oagAi::Ref tgt)
{
    assert(_graph.getDirtyMarker() == _dirtyMarker);

    const bool isTerminal = (_graph.getNodeType(tgt) == oagAi::Node::TERMINAL);
    if (!isTerminal) {
        assert("not implemented yet!\n" && 0);
    } else {
        _graph.setTerminalDriver(tgt, src);
#if 0
        lbool val = getValue(ref2lit(src, 0));
        if (val != l_Undef) {
            Lit l = Lit(ref2id(tgt), 0);
            check(enqueue(val==l_True ? l: !l));
            check(propagate());
        }
#endif
    }

    _dirtyMarker = _graph.getDirtyMarker();
}

/// Factory: allocates a CircuitSAToag for graph @p g with COs @p o using
/// `new`.  The returned pointer is not released anywhere in this file.
CircuitSAT* 
CircuitSAT::createOAG(oagAi::Graph &g, const std::vector<oagAi::Ref> &o)
{
    CircuitSAT *solver = new CircuitSAToag(g, o);
    return solver;
}

}

// vim:et:
