/*
Author: Kai-hui Chang <changkh@eecs.umich.edu>
*/

#include <iostream>
#include <string>
#include <set>
#include "oagResynInternalCkt.h"
#include "oagResynUtil.h"

using namespace std;
using namespace oa;

//#define DEBUG

namespace oagResyn {

// *****************************************************************************
// internalCkt()
//
/// \brief Constructor.
//
// *****************************************************************************
internalCkt::internalCkt(oa::oaDesign *design)
{
    // A design with a top block is mandatory; everything else is derived
    // from that block by setup().
    assert(design);
    this->design = design;
    block = design->getTopBlock();
    assert(block);

    // No tie-off nets are known until buildInternalCircuit() looks for them.
    tie1Wire = NULL;
    tie0Wire = NULL;

    setup();
}

// *****************************************************************************
// analyzeCircuit()
//
/// \brief Analyze the circuit. 
///
/// Identify state registers, PIs and POs. Also build cell library for simulation. Will be called by setup.
//
// *****************************************************************************
void internalCkt::analyzeCircuit() {
    set<oaInst*> stateSet;
    set<oaNet*>::iterator setNetIter;
    set<oaInst*>::iterator setIter;

    // Identify PIs/POs
    try
    {
        oaIter<oaTerm>   oaTIter(block->getTerms());
        while (oaTerm*  term = oaTIter.getNext())
        {
            oaTermType termType = term->getTermType();
            oaNet* net= term->getNet();
            if( termType == oacOutputTermType)
            {
                oaIter<oaBitNet>   oaBNIter(net->getSingleBitMembers());
                while (oaNet*  bitNet = static_cast<oaNet*>(oaBNIter.getNext()))
                    outputNets.push_back(bitNet);
            }
            if( termType == oacInputTermType)
            {
                oaIter<oaBitNet>   oaBNIter(net->getSingleBitMembers());
                while (oaNet*  bitNet = static_cast<oaNet*>(oaBNIter.getNext()))
                {
                    inputNets.push_back(bitNet);
                }
            }
        }
        oaIter<oaNet>   oaNetIter(block->getNets());
        // Find all bitNets
        while (oaNet*  net = oaNetIter.getNext())
        {
            oaIter<oaBitNet>   oaBNIter(net->getSingleBitMembers());
            while (oaNet*  bitNet = static_cast<oaNet*>(oaBNIter.getNext()))
            {
                if (net->getNumBits() == 1 && net != bitNet)
                    allBitNets.insert(net);
                else
                    allBitNets.insert(bitNet);
            }
        }
    }
    catch (oa::oaException &excp)
    {
        printf("Error: %s\n", (const char *)excp.getMsg());
        return;
    }
    
    try
    {
        oaIter<oaInst>   oaInstIter(block->getInsts());
        while (oaInst*  inst = oaInstIter.getNext()) 
        {
            string modName;
            
            instGetModName(inst, modName);
            // Check if module name contains DFF.
            if (modName.find(dffKey.c_str()) != string::npos)
                stateSet.insert(inst);
            else
                libManager.addCell(modName);
        }
    }
    catch (oa::oaException &excp)
    {
        printf("Error: %s\n", (const char *)excp.getMsg());
        return;
    }
    // Move stateSet to stateInsts
    for (setIter= stateSet.begin(); setIter != stateSet.end(); setIter++)
    {
        stateInsts.push_back(*setIter);
    }
    for (setNetIter= allBitNets.begin(); setNetIter != allBitNets.end(); setNetIter++)
        allBitNetsVec.push_back(*setNetIter);
}

// *****************************************************************************
// getStates()
//
/// \brief Gets all state instances inside this design. 
// *****************************************************************************
void internalCkt::getStates(vector<oaInst*>& result) {
    // Hand back a copy of the state-register list built by analyzeCircuit().
    vector<oaInst*> snapshot(stateInsts);
    result.swap(snapshot);
}

// *****************************************************************************
// getAllNets()
//
/// \brief Gets all nets in the circuit (oaBitNet).
// *****************************************************************************
void internalCkt::getAllNets(vector<oaNet*>& result) {
    // Copy of allBitNetsVec, the vector form of allBitNets built by
    // analyzeCircuit().
    vector<oaNet*> snapshot(allBitNetsVec);
    result.swap(snapshot);
}

// *****************************************************************************
// getClockNets()
//
/// \brief Gets all nets marked as clock by DFF's clock.
// *****************************************************************************
void internalCkt::getClockNets(vector<oaNet*>& result) {
    // clockNets is a set filled from the dffCLK pins of the state registers,
    // so the result comes out in set order without duplicates.
    result.assign(clockNets.begin(), clockNets.end());
}

// *****************************************************************************
// getInputs()
//
/// \brief Gets all primary input oaNets inside this design.
// *****************************************************************************
void internalCkt::getInputs(vector<oaNet*>& result) {
    // Copy of the bit-blasted input-terminal nets collected by analyzeCircuit().
    vector<oaNet*> snapshot(inputNets);
    result.swap(snapshot);
}

// *****************************************************************************
// getOutputs()
//
/// \brief Gets all primary output oaNets inside this design.
// *****************************************************************************
void internalCkt::getOutputs(vector<oaNet*>& result) {
    // Copy of the bit-blasted output-terminal nets collected by analyzeCircuit().
    vector<oaNet*> snapshot(outputNets);
    result.swap(snapshot);
}

// *****************************************************************************
// printInOutState()
//
/// \brief Print all PI/PO/State
// *****************************************************************************
void internalCkt::printInOutState()
{
    // Three titled sections, one name per line: PIs, POs, state registers.
    cout<<"Input terminals:"<<endl;
    for (size_t i = 0; i < inputNets.size(); ++i)
        cout<<getNameFromOA(inputNets[i])<<endl;

    cout<<"Output terminals:"<<endl;
    for (size_t i = 0; i < outputNets.size(); ++i)
        cout<<getNameFromOA(outputNets[i])<<endl;

    cout<<"State registers:"<<endl;
    for (size_t i = 0; i < stateInsts.size(); ++i)
        cout<<getNameFromOA(stateInsts[i])<<endl;
}

// *****************************************************************************
// instGetModName()
//
/// \brief Get the module name of an OA instance
// *****************************************************************************
void internalCkt::instGetModName(oaInst* inInst, string& modName) {
    // The cell (master) name is read through the instance header, using the
    // native OpenAccess namespace.
    oa::oaNativeNS nativeNS;
    oa::oaString cellName;

    inInst->getHeader()->getCellName(nativeNS, cellName);
    modName.assign(static_cast<const char*>(cellName));
}

// *****************************************************************************
// buildInternalCircuit()
//
/// \brief Build some internal data structures for simulation.
// *****************************************************************************
void internalCkt::buildInternalCircuit()
{
    simInst* newInst;
    simWire* newWire;
    string cellName;
    vector<oaInst*>::iterator oaInstIter;
    vector<oaNet*>::iterator oaWireIter;
    set<oaNet*>::iterator netIter;

    // Build all wires first and the map from nets to wires. Only one wire
    // will be created for equivalent net
    try
    {
        // Create all nets first
        for (netIter= allBitNets.begin(); netIter != allBitNets.end();
             netIter++)
        {
            oaNet* net= *netIter;
            oaBitNet* bitNet= static_cast<oaBitNet*>(net);
            bool isTie0, isTie1;
            
            // Also identify tie0/tie1 here
            if (tie0.compare(getNameFromOA(net)) == 0)
            {
                isTie0= true;
                isTie1= false;
            }
            else if (tie1.compare(getNameFromOA(net)) == 0)
            {
                isTie0= false;
                isTie1= true;
            } 
            else
            {
                isTie0= false;
                isTie1= false;
            }
            newWire= static_cast<simWire*>(new(simWire));
            newWire->origWire= bitNet;
            netSimWireMap[bitNet]= newWire;

            if (isTie0 || isTie1)
            {
                newWire->isPI= true;
                if (isTie0)
                    tie0Wire= bitNet;
                else
                    tie1Wire= bitNet;
            } 
        }
    }
    catch (oa::oaException &excp)
    {
        printf("Error: %s\n", (const char *)excp.getMsg());
        return;
    }
    try
    {
        oaIter<oaInst>   oaInstIter(block->getInsts());
        while (oaInst*  inst = oaInstIter.getNext())
        {
            // Add inst to circuit
            newInst= static_cast<simInst*>(new(simInst)); 
            newInst->origInst= inst;
            instGetModName(inst, cellName);
            newInst->cell= libManager.lookUpCell(cellName);
            instSimInstMap[inst]= newInst;
        }
    }
    catch (oa::oaException &excp)
    {
        printf("Error: %s\n", (const char *)excp.getMsg());
        return;
    }
    // Mark PI/PO/State
    for (oaInstIter= stateInsts.begin(); oaInstIter != stateInsts.end(); oaInstIter++)
    {   
        oaInst* inst= *oaInstIter;

        newInst= instSimInstMap[*oaInstIter];
        newInst->isStateReg= true;
        oaIter<oaInstTerm>   oaITIter(inst->getInstTerms());
        while (oaInstTerm*  instTerm = oaITIter.getNext())
        {
            oaTerm* term = instTerm->getTerm();
            oaTermType termType = term->getTermType();
            oaNet* net = instTerm->getNet();
            newWire= netSimWireMap[net];

            if( termType == oacOutputTermType)
            {
                if (dffQ.compare(getNameFromOA(term)) == 0)
                    newWire->isRegOut= true;
                else if (dffQN.compare(getNameFromOA(term)) == 0)
                {
                    newWire->isRegOut= true;
                    newWire->isRegOutQN= true;
                    oaIter<oaInstTerm>   oaITIter2(inst->getInstTerms());
                    while (oaInstTerm*  instTerm2 = oaITIter2.getNext())
                    {
                        oaTerm* term2 = instTerm2->getTerm();
                        oaTermType termType2 = term2->getTermType();
                        if (dffQ.compare(getNameFromOA(term2)) == 0)
                        {
                            oaNet* net3= instTerm2->getNet();
                            newWire->regQWire= net3;
                        }
                    }
                }
                stateNets.push_back(net);
            }
            else 
            {
                if (dffD.compare(getNameFromOA(term)) == 0)
                {
                    newWire->isRegD= true;
                    stateInNets.push_back(net);
                }
                else if (dffCLK.compare(getNameFromOA(term)) == 0)
                {
                    clockNets.insert(net);
                }
            }
        }
    }
    for (oaWireIter= inputNets.begin(); oaWireIter != inputNets.end(); oaWireIter++)
    {
        newWire= netSimWireMap[*oaWireIter];
        newWire->isPI= true;
    }
    for (oaWireIter= outputNets.begin(); oaWireIter != outputNets.end(); oaWireIter++)
    {
        newWire= netSimWireMap[*oaWireIter];
        newWire->isPO= true;
    }
}


// *****************************************************************************
// levelizeInternalCircuit()
//
/// \brief Levelize the OpenAccess circuit for simulation.
// *****************************************************************************
void internalCkt::levelizeInternalCircuit()
{
    simWire* myWire;
    simInst* myInst;
    oaInst* inst;
    oaInstTerm *instTerm;
    vector<oaInst*> sInsts;
    vector<oaInst*>::iterator instIter;
    unsigned int count, count2;
    int level;
    bool found;
    vector<oaNet*>::iterator netIter;

    levelizedInst.clear();
    // Use wire's simVec to determine if all inputs are available. If 0,
    // not yet available. If 1, already available.
   
    // clear sim value
    for (netIter= allBitNetsVec.begin(); netIter != allBitNetsVec.end(); netIter++)
    {
        myWire= netGetSimWire(*netIter);
        setVector(*netIter, 0, 0);
    }
    // set PI and RegOut's value to 1
    for (netIter= allBitNetsVec.begin(); netIter != allBitNetsVec.end(); netIter++)
    {
        myWire= netGetSimWire(*netIter);
        if (myWire->isPI || myWire->isRegOut)
        {
            setVector(*netIter, 1, 0);
        }
    }
    oaIter<oaInst> oaIIter(block->getInsts());
    
    // Add combinational instances to sInsts
    while (oaInst* inst= oaIIter.getNext())
    {
        myInst= instGetSimInst(inst);
        if (myInst->isStateReg == false)
            sInsts.push_back(inst);
        myInst->visited= false;
    }
    count= 0;
    count2= 0;
    while (levelizedInst.size() < sInsts.size())
    {
        count2++;
        if (count2 > sInsts.size())
            break;
        for (instIter= sInsts.begin(); instIter != sInsts.end(); instIter++)
        {
            inst= *instIter;
            myInst= instGetSimInst(inst);
            if (myInst->visited)
                continue;
            level= 0;
            oaIter<oaInstTerm>   oaITIter(inst->getInstTerms(oacInstTermIterNotImplicit | oacInstTermIterEquivNets));
            found= false;
            for (instTerm = oaITIter.getNext(); instTerm; instTerm = oaITIter.getNext())
            {
                oaTerm* term = instTerm->getTerm();
                oaTermType termType = term->getTermType();
                if( termType == oacInputTermType)
                {
                    oaNet* myNet= instTerm->getNet();
                    myWire= netGetSimWire(myNet);
                    if (myWire && myWire->simVec == 0)
                    {
                        found= true;
                        break;
                    }
                }
            }
            if (found == false)
            {
                levelizedInst.push_back(inst);
                myInst->visited= true;
                count++;
                oaIter<oaInstTerm>   oaITIter(inst->getInstTerms(oacInstTermIterNotImplicit | oacInstTermIterEquivNets));
                for (instTerm = oaITIter.getNext(); instTerm; instTerm = oaITIter.getNext())
                {
                    oaTerm* term = instTerm->getTerm();
                    oaTermType termType = term->getTermType();
                    if( termType == oacOutputTermType)
                    {
                        oaNet* myNet= instTerm->getNet();
                        myWire= netGetSimWire(myNet);
                        if (myWire)
                        {
                            setVector(myNet, 1, 0);
                        }
                    }
                }
            }
        }
    }
    if (count != sInsts.size())
        cout<<"Could not levelize internal circuit"<<endl;
    // clear visited flag of instances
    for (instIter= levelizedInst.begin(); instIter != levelizedInst.end(); instIter++)
    {
        (instGetSimInst(*instIter))->visited= false;
    }
    // Set tie0/tie1
    if (tie0Wire)
    {
        myWire= netGetSimWire(tie0Wire);
        myWire->simVec= 0;
        setVectorEquiNet(tie0Wire, 0, 0);
    }
    if (tie1Wire)
    {
        myWire= netGetSimWire(tie1Wire);
        myWire->simVec= ~0;
        setVectorEquiNet(tie1Wire, ~0, 0);
    }
}

// *****************************************************************************
// getStateNets()
//
/// \brief Get nets connected to the outputs of state registers.
//
// *****************************************************************************
void internalCkt::getStateNets(vector<oaNet*>& result)
{
    // Copy of the nets on the state registers' output pins, as gathered by
    // buildInternalCircuit().
    vector<oaNet*> snapshot(stateNets);
    result.swap(snapshot);
}

// *****************************************************************************
// getStateInNets()
//
/// \brief Get nets connected to the inputs of state registers.
//
// *****************************************************************************
void internalCkt::getStateInNets(vector<oaNet*>& result)
{
    // Copy of the nets on the state registers' dffD pins, as gathered by
    // buildInternalCircuit().
    vector<oaNet*> snapshot(stateInNets);
    result.swap(snapshot);
}


// *****************************************************************************
// setup()
//
/// \brief Setup internal data structure. 
///
/// Note that setKeyWords must be called before this routine to be effective.
//
// *****************************************************************************
void internalCkt::setup()
{
    analyzeCircuit();
    libManager.genLogicInfo(design);
    buildInternalCircuit();
    levelizeInternalCircuit();
}

// *****************************************************************************
// cleanUp()
//
/// \brief Destroy datastructures built and free the memory.
//
// *****************************************************************************
void internalCkt::cleanUp()
{
    stateInsts.clear();
    inputNets.clear();
    outputNets.clear();
    stateNets.clear();
    levelizedInst.clear();
    allBitNets.clear();
}

// *****************************************************************************
// ~internalCkt()
//
/// \brief Destructor.
//
// *****************************************************************************
internalCkt::~internalCkt()
{
    // Teardown is delegated entirely to cleanUp().
    this->cleanUp();
}

// *****************************************************************************
// getInputConeInsts()
//
/// \brief Return the oaInst in the input cones of nets along with PIs.
//
// *****************************************************************************
void internalCkt::getInputConeInsts(vector<oaNet*>& nets, vector<oaInst*>& insts, vector<oaNet*>& inputNets)
{
    // Collects into insts the combinational instances in the fan-in cones of
    // nets. Into inputNets it collects the cone's boundary nets: nets with no
    // instance driver, and nets driven by state registers.
    // NOTE: the parameter inputNets shadows the member of the same name; every
    // use below refers to the caller's output vector.
    vector<oaNet*>::iterator netIter;
    vector<oaInst*> stack;
    vector<oaInst*> visitedInst;
    vector<oaInst*>::iterator instIter;
    oaInstTerm* instTerm, *driverTerm;
    oaInst* inst, *inst2;
    simInst* simInst2, *simInst3;
    oaNet* net2;
    set<oaNet*> inputSet;
    set<oaNet*>::iterator setIter;

    insts.clear();
    inputNets.clear();
    // Seed the DFS stack with the driver of each requested net.
    for (netIter= nets.begin(); netIter != nets.end(); ++netIter)
    {
        instTerm= util::netGetInstDriver(*netIter);
        if (instTerm == NULL)
        {
            inputSet.insert(*netIter);   // no instance drives it
            continue;
        }
        inst= instTerm->getInst();
        if (inst)
        {
            simInst2= instGetSimInst(inst);
            if (simInst2->visited == false)
            {
                stack.push_back(inst);
                simInst2->visited= true;
                visitedInst.push_back(inst);
            }
            if (simInst2->isStateReg)
                inputSet.insert(*netIter);
        }
    }
    // DFS backwards through combinational instances; registers stop it.
    while (!stack.empty())
    {
        inst= stack.back();
        stack.pop_back();
        simInst2= instGetSimInst(inst);
        if (simInst2->isStateReg)
            continue;

        insts.push_back(inst);
        oaIter<oaInstTerm>   oaITIter(inst->getInstTerms(oacInstTermIterNotImplicit | oacInstTermIterEquivNets));
        for (instTerm = oaITIter.getNext(); instTerm; instTerm = oaITIter.getNext())
        {
            oaTerm* term = instTerm->getTerm();
            oaTermType termType = term->getTermType();
            if( termType != oacInputTermType)
                continue;
            net2= instTerm->getNet();
            if (net2 == NULL)
                continue;
            driverTerm= util::netGetInstDriver(net2);
            if (driverTerm == NULL)
            {
                inputSet.insert(net2);
                continue;
            }
            inst2= driverTerm->getInst();
            simInst3= instGetSimInst(inst2);
            // A net driven by a register is a cone input however often the
            // register is reached. Test this BEFORE the visited check, so a
            // second output net of an already-visited register (e.g. QN after
            // Q) is not dropped. The seed loop above does the same.
            if (simInst3->isStateReg)
                inputSet.insert(net2);
            if (simInst3->visited == false)
            {
                simInst3->visited= true;
                visitedInst.push_back(inst2);
                stack.push_back(inst2);
            }
        }
    }
    // Restore the visited flags touched by this traversal.
    for (instIter= visitedInst.begin(); instIter != visitedInst.end(); ++instIter)
    {
        simInst3= instGetSimInst(*instIter);
        simInst3->visited= false;
    }
    for (setIter= inputSet.begin(); setIter != inputSet.end(); ++setIter)
        inputNets.push_back(*setIter);
}

// *****************************************************************************
// instGetSimInst()
//
/// \brief Return the SimInst of the oaInst.
//
// *****************************************************************************
simInst* internalCkt::instGetSimInst(oaInst* inst)
{
    // Looked up with operator[]: an instance without a simInst comes back as
    // NULL (and, assuming a std::map, leaves a NULL entry behind).
    simInst* found= instSimInstMap[inst];
    return found;
}

// *****************************************************************************
// netGetSimWire()
//
/// \brief Return the SimWire of the oaNet.
//
// *****************************************************************************
simWire* internalCkt::netGetSimWire(oaNet* net)
{
    // Looked up with operator[]: a net without a simWire comes back as NULL
    // (and, assuming a std::map, leaves a NULL entry behind).
    simWire* found= netSimWireMap[net];
    return found;
}

// *****************************************************************************
// setVectorEquiNet()
//
/// \brief Set a vector value to all the nets equivalent to net.
//
// *****************************************************************************
void internalCkt::setVectorEquiNet(oaNet* net, SimulationVector vec, int isw= 0)
{
    // Writes vec into the simWire of every net OpenAccess reports as
    // equivalent to net: into simVec when isw == 0, otherwise into nextSimVec.
    // NOTE(review): whether getEquivalentNets() includes net itself is up to
    // OpenAccess; levelizeInternalCircuit() also sets net's own wire
    // explicitly. net is assumed to be an oaBitNet.
    oaBitNet* bitNet= static_cast<oaBitNet*>(net);
    oaIter<oaBitNet>   oaBNIter(bitNet->getEquivalentNets());
    for (oaBitNet* bitNet2 = oaBNIter.getNext(); bitNet2; bitNet2 = oaBNIter.getNext())
    {
        simWire* myWire= netGetSimWire(bitNet2);
        // Skip equivalent nets that have no simWire instead of
        // dereferencing NULL.
        if (myWire == NULL)
            continue;
        if (isw == 0)
            myWire->simVec= vec;
        else
            myWire->nextSimVec= vec;
    }
}



} // End of namespace
// vim: ci et sw=4 sts=4
