#include "oagRedunFlattener.h"
#include "oagFunc.h"
#include "oagFuncQueryOcc.h"

using namespace std;
using namespace oa;
// Iterates `i` over every element of container `x`. `__typeof__` (rather than
// `typeof`) is the GNU spelling that is also accepted in strict ISO modes.
#define foreach(i,x) for(__typeof__((x).begin()) i=(x).begin(); i!=(x).end(); ++i)

namespace {
oaNativeNS ns;
oaString   str;

template <class T>
const char *
getName(T* inst) {
    inst->getName(ns, str);
    return str;
}

template <>
const char *
getName(oaInstTerm* it) {
    oaString s1;
    it->getInst()->getName(ns, str);
    it->getTerm()->getName(ns, s1);
    str += ':';
    str += s1;
    return str;
}

oaBitNet*
toBitNet(oaNet *n)
{
    return n->getBit(0);
}
oaOccBitNet*
toBitNet(oaOccNet *n)
{
    return n->getBit(0);
}
oaModBitNet*
toBitNet(oaModNet *n)
{
    return n->getBit(0);
}
}
namespace oagRedun {

/**
 * Checks that all nets are driven at the block level. Nets inside
 * instantiated blocks are not checked.
 */
bool
hasFloatingNets(oaBlock *block)
{
    set<oaNet*> drivenNets;

    oaIter<oaInst> instIter(block->getInsts());
    while (oaInst *inst = instIter.getNext()) {
        oaIter<oaInstTerm> it(inst->getInstTerms(oacInstTermIterAll));
        oaInstTerm *iterm;
        bool driven = false;
        while ((iterm = it.getNext())) {

            if (iterm->getTerm()->getTermType() == oacInputTermType) {
                oaNet *net = iterm->getNet();
                if (drivenNets.find(net) != drivenNets.end()) 
                    continue;
                if (net->getSigType() == oacTieHiSigType)
                    driven = true;
                if (net->getSigType() == oacTieLoSigType)
                    driven = true;

                if (!driven) {
                    oaIter<oaInstTerm> it2(net->getInstTerms(oacInstTermIterAll));
                    oaInstTerm *iterm2;
                    while ((iterm2 = it2.getNext())) {
                        if (iterm2->getTerm()->getTermType() == oacOutputTermType) {
                            driven = true;
                            break;
                        }
                    }
                }
                if (!driven) {
                    oaIter<oaTerm> it3(net->getTerms(oacTermIterAll));
                    oaTerm *iterm3;
                    while ((iterm3 = it3.getNext())) {
                        if (iterm3->getTermType() == oacInputTermType) {
                            driven = true;
                            break;
                        }
                    }
                }

                if (!driven) {
                    //printf("Net %s is undriven, adding an input terminal.\n", getName(net));
                    //oaName name;
                    //net->getName(name);
                    //oaTerm::create(net, name, oacInputTermType);
                    return true;
                } else {
                    drivenNets.insert(net);
                }
            }
        }
    }
    return false;
}

/**
 * Returns the AIG reference associated with an instance terminal.
 *
 * - Input instTerms: the terminal recorded for them in _iInstTerms.
 * - Output instTerms: the driver registered in _drivers for bit 0 of the
 *   connected net.
 *
 * Any other terminal type, or a missing entry, is a programming error. It
 * trips an assertion in debug builds and yields a null reference otherwise,
 * instead of falling off the end of the function or dereferencing end().
 */
oagAi::Ref
Flattener::getTerm(oaInstTerm *term) const
{
    if (term->getTerm()->getTermType() == oacInputTermType) {
        map<oaInstTerm*, oagAi::Ref>::const_iterator it(_iInstTerms.find(term));
        if (it == _iInstTerms.end()) {
            assert(!"input instTerm has no AIG terminal");
            return oagAi::Graph::getNull();
        }
        return it->second;
    }

    if (term->getTerm()->getTermType() == oacOutputTermType) {
        oaBitNet *net = toBitNet(term->getNet());
        map<oaBitNet*, oagAi::Ref>::const_iterator it(_drivers.find(net));
        if (it == _drivers.end()) {
            assert(!"output instTerm net has no driver");
            return oagAi::Graph::getNull();
        }
        return it->second;
    }

    // Inout and other terminal types are not modelled by the flattener.
    assert(!"unsupported terminal type");
    return oagAi::Graph::getNull();
}

/**
 * Registers `ref` as the AIG function driving `net`. A net may be given a
 * driver only once; registering a second one trips the assertion.
 */
void
Flattener::setDriver(oaBitNet *net, oagAi::Ref ref)
{
    const bool inserted =
        _drivers.insert(map<oaBitNet*, oagAi::Ref>::value_type(net, ref)).second;
    (void)inserted;  // only read by assert(), which vanishes under NDEBUG
    assert(inserted && "driver already exists!");
}

/**
 * Looks up the AIG reference driving `net`.
 *
 * The net itself is checked first, then every net reported by
 * getEquivalentNets(). Returns oagAi::Graph::getNull() when none of them has
 * a registered driver.
 */
oagAi::Ref
Flattener::getDriver(oaBitNet *net) const
{
    map<oaBitNet*, oagAi::Ref>::const_iterator found = _drivers.find(net);
    if (found != _drivers.end())
        return found->second;

    oa::oaIter<oa::oaBitNet> equivIter(net->getEquivalentNets());
    for (oa::oaBitNet *equiv = equivIter.getNext(); equiv; equiv = equivIter.getNext()) {
        found = _drivers.find(equiv);
        if (found != _drivers.end())
            return found->second;
    }

    // Some designs have genuinely undriven nets (e.g. floating outputs).
    return oagAi::Graph::getNull();
}

void
Flattener::setupInputs()
{
    oaTerm *term;
    oaIter<oaTerm> termIter(_block->getTerms(oacTermIterSingleBit));
    while ((term = termIter.getNext())) {
        if (term->getTermType() == oacInputTermType) {
            oagAi::Ref ref = _are.getRef(_are.genericNewTerminal());

            //printf("INPUT %s ", getName(term));
            //printf("net %s ", getName(term->getNet()));
            //printf("ref %d\n", ref);

            setDriver(toBitNet(term->getNet()), ref);
            _inputs.push_back(ref);
        }
    }

    // tie constant nets
    oaNet *net;
    oaIter<oaNet> it(_block->getNets(oacNetIterSingleBit));
    while ((net = it.getNext())) {
        switch (net->getSigType()) {
            case oacTieHiSigType:
                setDriver(toBitNet(net), _graph.constantOne());
                break;
            case oacTieLoSigType:
                setDriver(toBitNet(net), _graph.constantZero());
                break;
            default:
                break;
        }
    }
}

void 
Flattener::setupInsts()
{
    using oagFunc::OccGraph;
    using oagFunc::OccRef;

    list<OccRef> states;
    const bool blackBoxSequentials = false;

    oaIter<oaInst> iter(_block->getInsts());
    while (oaInst *inst = iter.getNext()) {

        oaOccurrence *occ = inst->getMaster()->getTopOccurrence();
        oagFunc::QueryOcc query(occ->getDesign(), &_are);

        states.clear();
        OccGraph::getStates(occ, states);

        oaInstTerm *term;
        oaIter<oaInstTerm> termIter(inst->getInstTerms(oacTermIterSingleBit));
        while ((term = termIter.getNext())) {
            if (term->getTerm()->getTermType() == oacInputTermType) {
                oaNet    * n = term->getNet();
                oaOccNet *on = term->getTerm()->getOccTerm()->getNet();

                OccRef instT = OccGraph::getNetToAiConnection(toBitNet(on));
                oagAi::Ref ref = _are.getRef(_are.genericNewTerminal());
                query.set(instT, _are.getFunc(ref));

                _iInstTerms.insert(make_pair(term, ref));
            }
        }

        if (!blackBoxSequentials) {
            foreach (it, states) {
                oagAi::Ref ref = _are.getRef(_are.genericNewTerminal());
                query.set(*it, _are.getFunc(ref));
                // mark as CI
                _inputs.push_back(ref);
            }
        }

        termIter = inst->getInstTerms(oacTermIterSingleBit);
        while ((term = termIter.getNext())) {
            if (term->getTerm()->getTermType() == oacOutputTermType) {
                oagAi::Ref ref = _are.getRef(_are.genericNewTerminal());

                oaNet    * n = term->getNet();
                oaOccNet *on = term->getTerm()->getOccTerm()->getNet();

                if (states.empty() or !blackBoxSequentials) {
                    OccRef instT = OccGraph::getNetToAiConnection(toBitNet(on));
                    query.get(instT);
                    _graph.setTerminalDriver(ref, _are.getRef(query.get(instT)));
                }

                setDriver(toBitNet(term->getNet()), ref);
            }
        }

        if (!blackBoxSequentials) {
            foreach (it, states) {
                OccRef nextState = OccGraph::getNextState(*it);
                oagAi::Ref ref0 = _are.getRef(query.get(nextState));
                oagAi::Ref ref  = _are.getRef(_are.genericNewTerminal());
                _graph.setTerminalDriver(ref, ref0);
                // mark as CO
                _outputs.push_back(ref);
                _graph.incrementExternalReferences(ref);
            }
        }
    }

    // internal stitching
    foreach (it, _iInstTerms) {
        oaInstTerm *term = it->first;
        oagAi::Ref     ref  = it->second;

        oagAi::Ref     ref0 = getDriver(toBitNet(term->getNet()));
        assert(!oagAi::Graph::isNull(ref0));
        _graph.setTerminalDriver(ref, ref0);
    }
}

void 
Flattener::setupOutputs()
{
    oaTerm *term;
    oaIter<oaTerm> termIter(_block->getTerms(oacTermIterSingleBit));
    while ((term = termIter.getNext())) {
        if (term->getTermType() == oacOutputTermType) {
            oagAi::Ref ref0 = getDriver(toBitNet(term->getNet()));
            // floating output
            if (oagAi::Graph::isNull(ref0)) continue;

            oagAi::Ref ref  = _are.getRef(_are.genericNewTerminal());
            _graph.setTerminalDriver(ref, ref0);
            _outputs.push_back(ref);
            _graph.incrementExternalReferences(ref);
        }
    }
}

/**
 * Flattens the top block of `d` into the AIG `g`.
 *
 * Primary-input terminals (including cut state bits) are appended to `i`.
 * Primary-output terminals (including next-state functions) are appended to
 * `o`.
 *
 * @throws FloatingNetException if the block has a floating instance input.
 * @throws CycleException if the resulting graph has a combinational cycle.
 */
Flattener::Flattener(oaDesign *d, oagAi::Graph &g, vector<oagAi::Ref> &i, vector<oagAi::Ref> &o)
    // _are is built from the parameter `g`, not from the member `_graph`.
    // This stays well-defined even if _are is declared before _graph.
    : _block(d->getTopBlock()), _graph(g), _inputs(i), _outputs(o), _are(g)
{
    if (hasFloatingNets(_block))
        throw Flattener::FloatingNetException();
    setupInputs();
    setupInsts();
    setupOutputs();
    if (g.hasCombinationalCycle())
        throw Flattener::CycleException();
}

/**
 * Destructor. _iInstTerms and _drivers are used as std::map objects in this
 * file, so nothing has to be released by hand here.
 */
Flattener::~Flattener()
{
}

}

// vim:et:
