
/*
Author: Kai-hui Chang <changkh@eecs.umich.edu>
*/

#include <iostream>
#include "oagResynSATCkt.h"

using namespace std;
using namespace oa;

namespace oagResyn {

// *****************************************************************************
// genComparator()
//
/// \brief Generate a comparator for input1/input2. 
///
/// \param offsetBase The offset of literal number used in this CNF 
/// \param input1 Input literals' numbers for inputs1
/// \param input2 Input literals' numbers for inputs2. 
/// \param outputCNF Generated CNF.
/// \param outputLitNO The literal number of the output signal. it is also the largest litNO used in this CNF.
//
// *****************************************************************************
void satCkt::genComparator(int offsetBase, vector<int>& input1, 
                           vector<int>& input2, clauseType& outputCNF, 
                           int &outputLitNO)
{
    int size, i, currLitNO;
    vector<int> clause;

    outputCNF.clear();
    if (input1.size() != input2.size())
    {
        cout<<"The number of inputs to a comparator must be identical."<<endl;
        return;
    }
    size= input1.size();
    // Comparator: XOR of input pairs and then OR them.
    // XOR: -Z A B 0, -Z -A -B 0, Z A -B 0, Z -A B 0
    for (i= 0; i < size; i++)
    {
        currLitNO= offsetBase + i;
        // -Z A B
        clause.push_back(-currLitNO);
        clause.push_back(input1[i]);
        clause.push_back(input2[i]);
        outputCNF.push_back(clause);
        clause.clear();
        // -Z -A -B
        clause.push_back(-currLitNO);
        clause.push_back(-input1[i]);
        clause.push_back(-input2[i]);
        outputCNF.push_back(clause);
        clause.clear();
        // Z A -B
        clause.push_back(currLitNO);
        clause.push_back(input1[i]);
        clause.push_back(-input2[i]);
        outputCNF.push_back(clause);
        clause.clear();
        // Z -A B
        clause.push_back(currLitNO);
        clause.push_back(-input1[i]);
        clause.push_back(input2[i]);
        outputCNF.push_back(clause);
        clause.clear();
    }
    // Now generate OR: -Z A B C ... 0, -A Z 0, -B Z 0, -C Z 0, ...
    // -Z A B C ...
    outputLitNO= offsetBase + size + 1;
    clause.push_back(-outputLitNO);
    for (i= 0; i < size; i++)
    {
        currLitNO= offsetBase + i;
        clause.push_back(currLitNO);
    }
    outputCNF.push_back(clause);
    clause.clear();
    // -A Z 0, -B Z 0 ...
    for (i= 0; i < size; i++)
    {
        currLitNO= offsetBase + i;
        clause.push_back(-currLitNO);
        clause.push_back(outputLitNO);
        outputCNF.push_back(clause);
        clause.clear();
    }
}

// *****************************************************************************
// genOneCounter()
//
/// \brief Generate a one counter.
///
/// The counter counts the number of ONEs in its input. 
/// \param offsetBase The offset of literal numbers used in this CNF. 
/// \param input It is the inputs to the counter.
/// \param outputCNF Returns the generated CNF. 
/// \param output Returns the outputs of the circuit, LSB in lower index.
//
// *****************************************************************************
void satCkt::genOneCounter(int offsetBase, vector<int>& input, 
                         clauseType& outputCNF, vector<int>& output)
{
    // It is implemented by hierarchies of adders
    vector<vector<int> > inputVec[2];
    unsigned int i;
    int side;
    vector<int> oneVec;
    int litNO;
    clauseType adderCNF;
    vector<int> outVec;

    side= 0;
    outputCNF.clear();
    litNO= offsetBase;
    for (i= 0; i < input.size(); i++)
    {
        oneVec.push_back(input[i]);
        inputVec[side].push_back(oneVec);
        oneVec.clear();
    }

    do
    {
        // Take two inputVec, add them, and put the result into the other side
        // cout<<"Change side"<<endl;
        for (i= 0; i < inputVec[side].size() / 2; i++)
        {
            genAdder(litNO, inputVec[side][i * 2], inputVec[side][i * 2 + 1],
                     adderCNF, outVec);
            // printClause(adderCNF);                     
            litNO= outVec[outVec.size() - 1] + 1;
            // add adderCNF to output CNF
            outputCNF.insert(outputCNF.end(), adderCNF.begin(), adderCNF.end());
            inputVec[1-side].push_back(outVec);
        }
        // If there is anything not added, propogate to the next side
        if (inputVec[side].size() % 2 == 1)
            inputVec[1-side].push_back(inputVec[side][inputVec[side].size() - 1]);
        inputVec[side].clear();
        side= 1 - side;
    } while (inputVec[side].size() > 1);
    output= inputVec[side][0];
}

// *****************************************************************************
// genAdder()
//
/// \brief Generate an adder. 
/// \param offsetBase The offset of literal numbers used in this CNF. 
/// \param input1 Input literal numbers for input1. 
/// \param input2 Input literal numbers for input2. (If input1's size != input2's size, 0 will be used.)
/// \param outputCNF Returns the generated CNF.
/// \param output Returns the outputs of the circuit. The largest litNO used is the MSB of output.
//
// *****************************************************************************
void satCkt::genAdder(int offsetBase, vector<int>& input1, vector<int>& input2,
                      clauseType& outputCNF, vector<int>& output)
{
    unsigned int i;
    int count, S, COUT, litNO, C;
    vector<int> clause;
    int j;
    
    outputCNF.clear();
    output.clear();
    litNO= offsetBase;
    if (input1.size() < input2.size())
    {
        // put 0
        count= input2.size() - input1.size();
        for (j= 0; j< count; j++)
        {
            input1.push_back(litNO);
            clause.push_back(-litNO);
            outputCNF.push_back(clause);
            clause.clear();
            litNO++;
        }
    }
    if (input1.size() > input2.size())
    {
        // put 0
        count= input1.size() - input2.size();
        for (j= 0; j < count; j++)
        {
            input2.push_back(litNO);
            clause.push_back(-litNO);
            outputCNF.push_back(clause);
            clause.clear();
            litNO++;
        }
    }
    if (input1.size() < 1)
        cout<<"Cannot generate adders with number of inputs < 1"<<endl;
    // Now they have the same size, with "0" means the input is 0.
    // Generate a half adder first
    // Half adder: -S A B 0 -S -A -B 0 S A -B 0 S -A B 0 -COUT A 0 -COUT B 0
    // -A -B COUT 0 
    S= litNO++;
    COUT= litNO++;
    output.push_back(S);
    // -S A B 0
    clause.push_back(-S);
    clause.push_back(input1[0]);
    clause.push_back(input2[0]);
    outputCNF.push_back(clause);
    clause.clear();
    // -S -A -B 0
    clause.push_back(-S);
    clause.push_back(-input1[0]);
    clause.push_back(-input2[0]);
    outputCNF.push_back(clause);
    clause.clear();
    // S A -B 0
    clause.push_back(S);
    clause.push_back(input1[0]);
    clause.push_back(-input2[0]);
    outputCNF.push_back(clause);
    clause.clear();
    // S -A B 0
    clause.push_back(S);
    clause.push_back(-input1[0]);
    clause.push_back(input2[0]);
    outputCNF.push_back(clause);
    clause.clear();
    // -COUT A 0
    clause.push_back(-COUT);
    clause.push_back(input1[0]);
    outputCNF.push_back(clause);
    clause.clear();
    // -COUT B 0
    clause.push_back(-COUT);
    clause.push_back(input2[0]);
    outputCNF.push_back(clause);
    clause.clear();
    // -A -B COUT 0
    clause.push_back(-input1[0]);
    clause.push_back(-input2[0]);
    clause.push_back(COUT);
    outputCNF.push_back(clause);
    clause.clear();
    // Full adder: 
    for (i= 1; i < input1.size(); i++)
    {
        C= COUT;
        S= litNO++;
        COUT= litNO++;
        // COUT -B -C 0 
        clause.push_back(COUT);
        clause.push_back(-input2[i]);
        clause.push_back(-C);
        outputCNF.push_back(clause);
        clause.clear();
        // COUT -A -C 0
        clause.push_back(COUT);
        clause.push_back(-input1[i]);
        clause.push_back(-C);
        outputCNF.push_back(clause);
        clause.clear();
        // COUT -A -B 0
        clause.push_back(COUT);
        clause.push_back(-input1[i]);
        clause.push_back(-input2[i]);
        outputCNF.push_back(clause);
        clause.clear();
        // -COUT A B 0
        clause.push_back(-COUT);
        clause.push_back(input1[i]);
        clause.push_back(input2[i]);
        outputCNF.push_back(clause);
        clause.clear();
        // -COUT A C 0
        clause.push_back(-COUT);
        clause.push_back(input1[i]);
        clause.push_back(C);
        outputCNF.push_back(clause);
        clause.clear();
        // -COUT B C 0
        clause.push_back(-COUT);
        clause.push_back(input2[i]);
        clause.push_back(C);
        outputCNF.push_back(clause);
        clause.clear();
        // -S A B C 0
        clause.push_back(-S);
        clause.push_back(input1[i]);
        clause.push_back(input2[i]);
        clause.push_back(C);
        outputCNF.push_back(clause);
        clause.clear();
        // -S A -B -C 0
        clause.push_back(-S);
        clause.push_back(input1[i]);
        clause.push_back(-input2[i]);
        clause.push_back(-C);
        outputCNF.push_back(clause);
        clause.clear();
        // -S -A B -C 0
        clause.push_back(-S);
        clause.push_back(-input1[i]);
        clause.push_back(input2[i]);
        clause.push_back(-C);
        outputCNF.push_back(clause);
        clause.clear();
        // -S -A -B C 0
        clause.push_back(-S);
        clause.push_back(-input1[i]);
        clause.push_back(-input2[i]);
        clause.push_back(C);
        outputCNF.push_back(clause);
        clause.clear();
        // S A B -C 0
        clause.push_back(S);
        clause.push_back(input1[i]);
        clause.push_back(input2[i]);
        clause.push_back(-C);
        outputCNF.push_back(clause);
        clause.clear();
        // S A -B C 0
        clause.push_back(S);
        clause.push_back(input1[i]);
        clause.push_back(-input2[i]);
        clause.push_back(C);
        outputCNF.push_back(clause);
        clause.clear();
        // S -A B C 0
        clause.push_back(S);
        clause.push_back(-input1[i]);
        clause.push_back(input2[i]);
        clause.push_back(C);
        outputCNF.push_back(clause);
        clause.clear();
        // S -A -B -C 0
        clause.push_back(S);
        clause.push_back(-input1[i]);
        clause.push_back(-input2[i]);
        clause.push_back(-C);
        outputCNF.push_back(clause);
        clause.clear();
        output.push_back(S);
    }
    output.push_back(COUT);
}

// *****************************************************************************
// printLitMap()
//
/// \brief Print every net in litNOMap together with its literal number.
///
/// Writes one line per map entry, "<net name>: <literal>", in map order.
//
// *****************************************************************************
void satCkt::printLitMap(std::map<oa::oaNet*, int>& litNOMap)
{
    typedef map<oa::oaNet*, int>::const_iterator EntryIter;

    for (EntryIter entry= litNOMap.begin(); entry != litNOMap.end(); ++entry)
        cout<<getNameFromOA(entry->first)<<": "<<entry->second<<endl;
}

// *****************************************************************************
// printClause()
//
/// \brief Print the clauses of inputCNF, one per line.
///
/// Each clause is printed as its literals separated by spaces and terminated
/// by "0" (DIMACS-style clause lines, without a header).
//
// *****************************************************************************
void satCkt::printClause(clauseType& inputCNF)
{
    clauseType::iterator clauseIt= inputCNF.begin();

    while (clauseIt != inputCNF.end())
    {
        for (vector<int>::iterator lit= clauseIt->begin(); lit != clauseIt->end(); ++lit)
            cout<<*lit<<" ";
        cout<<"0"<<endl;
        ++clauseIt;
    }
}

// *****************************************************************************
// printIntVec()
//
/// \brief Print the integers in inputVec (e.g. the literals of one clause).
///
/// Every value is followed by a space; the line ends with a newline.
//
// *****************************************************************************
void satCkt::printIntVec(vector<int>& inputVec)
{
    for (vector<int>::size_type k= 0; k < inputVec.size(); ++k)
        cout<<inputVec[k]<<" ";
    cout<<endl;
}

// *****************************************************************************
// genMux()
//
/// \brief Generates a 2-to-1 multiplexer: output = sel ? input2 : input1.
///
/// All four arguments are existing literal numbers; no new literals are
/// allocated. (The old note about a "largestLitNO" result was stale: this
/// function has no such output.)
///
/// \param input1 Literal passed to the output when sel is 0.
/// \param input2 Literal passed to the output when sel is 1.
/// \param sel Select literal.
/// \param output Literal constrained to equal the selected input.
/// \param outputCNF Cleared, then filled with the four MUX clauses.
//
// *****************************************************************************
void satCkt::genMux(int input1, int input2, int sel, int output,
                    clauseType& outputCNF)
{
    // MUX: -Z -S B 0 -Z S A 0 Z S -A 0 Z -S -B 0
    vector<int> clause(3);

    outputCNF.clear();
    // sel=1: output -> input2
    clause[0]= -output; clause[1]= -sel; clause[2]= input2;
    outputCNF.push_back(clause);
    // sel=0: output -> input1
    clause[0]= -output; clause[1]= sel;  clause[2]= input1;
    outputCNF.push_back(clause);
    // sel=0: input1 -> output
    clause[0]= output;  clause[1]= sel;  clause[2]= -input1;
    outputCNF.push_back(clause);
    // sel=1: input2 -> output
    clause[0]= output;  clause[1]= -sel; clause[2]= -input2;
    outputCNF.push_back(clause);
}


// *****************************************************************************
// genInstMux()
//
/// \brief Generates CNF circuit for insts and inserts a MUX into every net.
///
/// For each net N in allNets a group of literals starting at its base
/// literal L (from getNetLitNO) is used:
///   - L   : value of N before the MUX (what N's driver produces)
///   - L+1 : MUX output; instance inputs are connected here (genInstsExec
///           with isw= 1)
///   - L+2 : alternative value, selected when the select line is 1
///           (not constrained in this function)
///   - L+3 : select line of the MUX
/// Primary-input nets also get a MUX (RTL error correction uses the MUX
/// output for resynthesis), but their select line is forced to 0 so a
/// diagnosis can never blame a PI.
///
/// \param design oaDesign
/// \param offsetBase The offset of literal numbers used in this CNF. 
/// \param insts The instances that need to generate SAT-CNF.
/// \param allNets All the nets in the circuit.
/// \param muxedNets Returns the nets that have been mux-inserted.
/// \param selLits Returns the literals of the select lines for muxedNets
/// \param muxOutLits Returns the literals of the output variables for muxedNets
/// \param outputCNF Returns the generated CNF (cleared first)
/// \param litNOMap Returns a map from oaNet to literal number (cleared first)
/// \param largestLitNO The largest litNO used in the generated CNF.
//
// *****************************************************************************
void satCkt::genInstMux(oa::oaDesign* design, int offsetBase, 
                        std::vector<oa::oaInst*>& insts,
                        std::vector<oa::oaNet*>& allNets,
                        std::vector<oa::oaNet*>& muxedNets,
                        std::vector<int>& selLits, std::vector<int>& muxOutLits,
                        clauseType& outputCNF, 
                        std::map<oa::oaNet*, int>& litNOMap,
                        int& largestLitNO)
{
    optInternalCkt* myCkt= cktManager.designGetOptInternalCkt(design);
    int litNO= offsetBase;
    clauseType muxCNF;
    vector<int> clause;

    muxedNets.clear();
    selLits.clear();
    muxOutLits.clear();
    outputCNF.clear();
    litNOMap.clear();
    for (vector<oaNet*>::iterator netIter= allNets.begin(); netIter != allNets.end(); netIter++)
    {
        oaNet* myNet= *netIter;
        optSimWire* myWire= myCkt->netGetSimWire(myNet);
        // NOTE(review): if myNet (or an equivalent net) was already numbered,
        // the same literals are reused and a duplicate MUX is emitted --
        // presumably allNets holds each net once; confirm.
        int currLit= getNetLitNO(myNet, litNOMap, litNO);
        genMux(currLit, currLit+2, currLit+3, currLit+1, muxCNF);
        outputCNF.insert(outputCNF.end(), muxCNF.begin(), muxCNF.end());
        muxedNets.push_back(myNet);
        selLits.push_back(currLit + 3);
        muxOutLits.push_back(currLit + 1);
        if (myWire->isPI)
        {
            // Keep the MUX for PIs but tie its select line to 0.
            clause.push_back(-(currLit + 3));
            outputCNF.push_back(clause);
            clause.clear();
        }
        // For a newly numbered net getNetLitNO already advanced litNO by one,
        // so this leaves one literal unused per net; kept to preserve the
        // existing literal numbering.
        litNO+= 4;
    }
    genInstsExec(design, litNO, insts, outputCNF, litNOMap, largestLitNO,
                 1);
}


// *****************************************************************************
// genInsts()
//
/// \brief Generates CNF circuit for insts (no MUXes inserted). 
///
/// litNOMap and outputCNF are cleared first; the clauses are then produced by
/// genInstsExec() with isw= 0, so instance inputs use each net's own literal.
///
/// \param design oaDesign
/// \param offsetBase The offset of literal numbers used in this CNF. 
/// \param insts The instances that need to generate SAT-CNF.
/// \param outputCNF Returns the generated CNF
/// \param litNOMap Returns a map from oaNet to literal number
/// \param largestLitNO The largest litNO used in the generated CNF.
//
// *****************************************************************************
void satCkt::genInsts(oaDesign* design, int offsetBase, 
                      vector<oaInst*>& insts, 
                      clauseType& outputCNF, map<oaNet*, int>& litNOMap, 
                      int& largestLitNO)
{
    litNOMap.clear();
    outputCNF.clear();
    genInstsExec(design, offsetBase, insts, outputCNF, litNOMap, largestLitNO,
                 0);
}

/// Generate SAT clauses for each instance in insts and append them to
/// outputCNF (neither outputCNF nor litNOMap is cleared here).
///
/// For every instance, literals are taken in this order: one per input pin,
/// one per output pin, then mySatCell->extraNO extra literals. A pin with no
/// connected net gets a fresh literal; a connected pin uses the net's literal
/// from getNetLitNO(). The cell's template clauses (1-based indices: inputs,
/// then outputs, then extras; negative = negated) are then instantiated.
///
/// If isw= 1: assume MUX has been inserted, so input pins connect to the
/// net literal + 1 (the MUX output).
/// NOTE(review): with isw= 1, an input net not yet in litNOMap gets a single
/// new literal L and the pin uses L+1, which is the next literal handed out.
/// Presumably genInstMux numbered every net beforehand; confirm.
///
/// largestLitNO returns the largest literal allocated here.
void satCkt::genInstsExec(oaDesign* design, int offsetBase, 
                          vector<oaInst*>& insts, 
                          clauseType& outputCNF, map<oaNet*, int>& litNOMap, 
                          int& largestLitNO, int isw)
{
    optInternalCkt* ckt= cktManager.designGetOptInternalCkt(design);
    int nextLit= offsetBase;
    const int inputShift= (isw == 1) ? 1 : 0;
    vector<int> pinLits, outLits, extraLits;
    vector<int> newClause;

    for (unsigned int k= 0; k < insts.size(); k++)
    {
        optSimInst* simInst= ckt->instGetSimInst(insts[k]);
        libCell* cell= simInst->cell;
        libSATCell* satCell= cell->satCell;

        // Input pin literals.
        pinLits.clear();
        for (unsigned int p= 0; p < cell->inputNo; p++)
        {
            optSimPort* port= simInst->inputs[p];
            optSimWire* wire= port->wire1;
            oaNet* net= wire ? wire->origWire : NULL;
            if (net == NULL)
                pinLits.push_back(nextLit++);
            else
                pinLits.push_back(getNetLitNO(net, litNOMap, nextLit) + inputShift);
        }

        // Output pin literals (always the net's own literal).
        outLits.clear();
        for (unsigned int p= 0; p < simInst->outputs.size(); p++)
        {
            optSimPort* port= simInst->outputs[p];
            optSimWire* wire= port->wire1;
            oaNet* net= wire ? wire->origWire : NULL;
            if (net == NULL)
                outLits.push_back(nextLit++);
            else
                outLits.push_back(getNetLitNO(net, litNOMap, nextLit));
        }

        // Fresh literals for the cell's internal variables.
        extraLits.clear();
        for (unsigned int e= 0; e < satCell->extraNO; e++)
            extraLits.push_back(nextLit++);

        // Instantiate the cell's template clauses with the literals above.
        for (clauseType::iterator tmpl= satCell->clauses.begin(); tmpl != satCell->clauses.end(); ++tmpl)
        {
            for (vector<int>::iterator ref= tmpl->begin(); ref != tmpl->end(); ++ref)
            {
                const bool negated= !(*ref > 0);
                unsigned int idx= negated ? -(*ref) : *ref;
                int mapped;
                if (idx <= satCell->inputNO)
                    mapped= pinLits[idx - 1];
                else if (idx <= (satCell->inputNO + satCell->outputNO))
                    mapped= outLits[idx - satCell->inputNO - 1];
                else
                    mapped= extraLits[idx - satCell->inputNO - satCell->outputNO - 1];
                newClause.push_back(negated ? -mapped : mapped);
            }
            outputCNF.push_back(newClause);
            newClause.clear();
        }
    }
    largestLitNO= nextLit - 1;
}

// *****************************************************************************
// getNetLitNO()
//
/// \brief Gets the literal number of a net, assigning one if needed. 
///
/// - If inWire is already in litNOMap, its literal is returned.
/// - Otherwise, if one of inWire's equivalent nets is in litNOMap, that
///   literal is returned (litNO unchanged, inWire itself not added).
/// - Otherwise litNO is assigned to inWire and to all of its equivalent nets,
///   litNO is increased by 1, and the assigned literal is returned.
///
/// NOTE(review): inWire is static_cast to oaBitNet* without a check; confirm
/// callers never pass bus or bundle nets.
//
// *****************************************************************************
int satCkt::getNetLitNO(oaNet* inWire, map<oaNet*, int>& litNOMap, int& litNO)
{
    map<oaNet*, int>::iterator hit= litNOMap.find(inWire);
    if (hit != litNOMap.end())
        return hit->second;

    oaBitNet* bitNet= static_cast<oaBitNet*>(inWire);
    if (bitNet->getEquivalentNets().getCount() > 0)
    {
        // Reuse the literal of an equivalent net that was numbered earlier.
        oaIter<oaBitNet> lookupIter(bitNet->getEquivalentNets());
        while (oaBitNet* equiv= lookupIter.getNext())
        {
            map<oaNet*, int>::iterator equivHit= litNOMap.find(equiv);
            if (equivHit != litNOMap.end())
                return equivHit->second;
        }
        // First of its equivalence group: share the new literal with all.
        oaIter<oaBitNet> assignIter(bitNet->getEquivalentNets());
        while (oaBitNet* equiv= assignIter.getNext())
            litNOMap[equiv]= litNO;
    }
    litNOMap[inWire]= litNO;
    return litNO++;
}


} // End of namespace
// vim: ci et sw=4 sts=4
