#include <cstdlib>
#include <iostream>
#include <ctime>
#include <vector>
#include <set>
#include "oaDesignDB.h"
#include "oagFuncOccGraph.h"
#include "oagFuncPrint.h"
#include "oagFuncQueryAi.h"
#include "oagFuncQueryOcc.h"
#include "oagSswAiGraphUtil.h"
#include "oagSswAiReasoningEngine.h"
#include "oagSswObsAiGraph.h"
#include "oagSswOdcSatSweepingEngine.h"
#include "oagSswSatSweepingEngine.h"
#include "oagUtilOptionParser.h"

using namespace std;
using namespace oa;

// *****************************************************************************
/// \brief The AI-graph references that make up one time frame of the
/// unrolled design. All members are filled in by unrollAndSweep().
struct Frame {
    /// Primary-input terminals of this frame.
    vector<oagAi::Ref> inputs;
    /// Primary-output functions of this frame.
    vector<oagAi::Ref> outputs;
    /// Present-state values feeding this frame (terminals in frame 0, the
    /// previous frame's next-state functions afterwards).
    vector<oagAi::Ref> states;
    /// Next-state functions computed in this frame.
    vector<oagAi::Ref> nextStates;
};

// *****************************************************************************
/// \brief Rewrites, in place, every reference in \a refs that \a engine has
/// merged into another node so that it points at that node's representative.
///
/// \param refs   references to update; entries without a representative are
///               left unchanged
/// \param engine any object providing hasRep(Ref) and getRep(Ref)
template<class T>
void
substituteReps(vector<oagAi::Ref> &refs,
               T                  &engine)
{
    typedef vector<oagAi::Ref>::iterator RefIter;

    for (RefIter pos = refs.begin(); pos != refs.end(); ++pos) {
        oagAi::Ref node = *pos;
        if (!engine.hasRep(node)) {
            continue;
        }
        *pos = engine.getRep(node);
    }
}

// *****************************************************************************
/// \brief Copies the rep relationships of \a nodes from \a engine to \a reps.
template<class T>
void
getRepMap(const list<oagAi::Ref>        &nodes,
          T                             &engine,
          map<oagAi::Ref, oagAi::Ref>   &reps)
{
    using namespace oagAi;

    oaUInt4 counter = 0;
    for (list<Ref>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) {
        Ref x = *it;
        if (engine.hasRep(x)) {
            reps[x] = engine.getRep(x);
            counter++;
        }
    }
    cout << "merged nodes in search nodes " << counter << endl;
}

// *****************************************************************************
/// \brief Counts the AND nodes of \a graph that do not appear as keys in
/// \a merged, i.e. that have not been replaced by a representative.
///
/// Every node returned by getAll() is normalized with getNonInverted()
/// before its type is checked and it is looked up in \a merged.
oaUInt4
getNumUnmergedAnds(oagAi::Graph                      &graph,
                   const map<oagAi::Ref, oagAi::Ref> &merged)
{
    using namespace oagAi;

    list<Ref> nodes;
    graph.getAll(nodes);

    oaUInt4 numUnmerged = 0;
    list<Ref>::iterator pos = nodes.begin();
    while (pos != nodes.end()) {
        Ref node = graph.getNonInverted(*pos);
        ++pos;

        if (graph.getNodeType(node) != Node::AND) {
            continue;
        }
        if (merged.count(node) == 0) {
            ++numUnmerged;
        }
    }
    return numUnmerged;
}

// *****************************************************************************
/// \brief Unrolls the sequential design \a design into a single AI graph, one
/// time frame at a time, SAT-sweeping the graph after each frame is added.
///
/// Frame 0 gets fresh terminals for every primary input and state bit. For
/// each later frame i + 1, new input terminals are created, the states of
/// frame i + 1 are the next-state functions of frame i, and the outputs and
/// next states of frame i are re-expressed through a QueryAi substitution.
/// After each sweep the number of AND nodes not yet merged is printed.
///
/// Options read from \a options:
///  - "randomSeed": seed for rand(); time(0) is used when absent.
///  - "unroll":     number of extra frames (numFrames = unroll + 1); must be a
///                  non-negative integer.
///  - "obsLevels":  when given, OdcSatSweepingEngine is run with this many
///                  observability levels on the nodes not seen in an earlier
///                  frame; otherwise SatSweepingEngine is run on the whole
///                  transitive fanin of the observable nodes.
///
/// Terminates the process with EXIT_FAILURE if "unroll" is negative.
void
unrollAndSweep(oaDesign              *design,
               oagUtil::OptionParser &options)
{
    using namespace oagAi;
    using namespace oagFunc;
    using namespace oagSsw;

    oagUtil::Option *seedOpt = options.getOption("randomSeed");
    if (seedOpt->isGiven()) {
        oaUInt4 seed = atoi(seedOpt->getValue());
        srand(seed);
    } else {
        srand(time(0));
    }

    // Validate the unroll count before sizing the frame vector: a negative
    // value would wrap around in the unsigned frame count, giving either an
    // empty vector (frames[0] below would be out of bounds) or a ~4G-element
    // allocation.
    oagUtil::Option *unrollOpt  = options.getOption("unroll");
    int              numUnrolls = atoi(unrollOpt->getValue());
    if (numUnrolls < 0) {
        cerr << "Error: unroll count must be non-negative: "
             << unrollOpt->getValue() << endl;
        exit(EXIT_FAILURE);
    }
    oaUInt4       numFrames = static_cast<oaUInt4>(numUnrolls) + 1;
    vector<Frame> frames(numFrames);

    // The sweeping mode and ODC depth are fixed for the whole run.
    oagUtil::Option *obsLevelsOpt = options.getOption("obsLevels");
    bool             useOdc       = obsLevelsOpt->isGiven();
    oaUInt4          numObsLevels = useOdc ? atoi(obsLevelsOpt->getValue()) : 0;

    oaOccurrence * occ = design->getTopOccurrence();
    list<OccRef> inputs, outputs, states;
    OccGraph::getInputs(occ, inputs);
    OccGraph::getOutputs(occ, outputs);
    OccGraph::getStates(occ, states);

    ObsAiGraph graph;
    graph.enableStructuralHashing();
    AiReasoningEngine aiEngine(graph);

    // Build frame 0: a fresh terminal for every primary input and state bit.
    QueryOcc query0(design, &aiEngine);
    for (list<OccRef>::iterator it = inputs.begin(); it != inputs.end(); ++it) {
        GenericFunction f = aiEngine.genericNewTerminal();
        Ref             x = aiEngine.getRef(f);

        query0.set(*it, f);
        frames[0].inputs.push_back(x);
    }
    for (list<OccRef>::iterator it = states.begin(); it != states.end(); ++it) {
        GenericFunction f = aiEngine.genericNewTerminal();
        Ref             x = aiEngine.getRef(f);

        query0.set(*it, f);
        frames[0].states.push_back(x);
    }

    // Frame-0 outputs and next-state functions are the first sweep targets.
    for (list<OccRef>::iterator it = outputs.begin(); it != outputs.end(); ++it) {
        OccRef output = *it;
        Ref    x      = aiEngine.getRef(query0.get(output));

        frames[0].outputs.push_back(x);
        graph.setObservable(x, true);
    }
    for (list<OccRef>::iterator it = states.begin(); it != states.end(); ++it) {
        OccRef state = *it;
        Ref    x     = aiEngine.getRef(query0.getNextState(state));

        frames[0].nextStates.push_back(x);
        graph.setObservable(x, true);
    }

    // Nodes already handed to a sweeping engine in an earlier frame.
    set<Ref> prevTfi;
    // Accumulated merged-node -> representative map over all frames.
    map<Ref, Ref> reps;
    // Interleave sweeping and unrolling.
    for (oaUInt4 i = 0; i < numFrames; ++i) {
        cout << "Frame " << i << ":" << endl;

        // Run SAT sweeping.
        list<Ref> observable;
        graph.getAllObservable(observable);

        list<Ref> tfi;
        AiGraphUtil::getTransitiveFanin(graph, observable, tfi);
        tfi.insert(tfi.end(), observable.begin(), observable.end());
        cout << "tfi size " << tfi.size() << endl;

        // Keep only the nodes that were not part of an earlier frame's tfi.
        list<Ref> searchNodes;
        for (list<Ref>::iterator it = tfi.begin(); it != tfi.end(); ++it) {
            if (prevTfi.insert(*it).second) {
                searchNodes.push_back(*it);
            }
        }
        cout << "search node size " << searchNodes.size() << endl;

        if (useOdc) {
            // Only the nodes new in this frame are swept, but representatives
            // are still collected over the whole tfi.
            OdcSatSweepingEngine sswEngine(graph, searchNodes);
            sswEngine.getLogger()->enable();

            sswEngine.run(numObsLevels);

            substituteReps(frames[i].inputs, sswEngine);
            substituteReps(frames[i].outputs, sswEngine);
            substituteReps(frames[i].states, sswEngine);
            substituteReps(frames[i].nextStates, sswEngine);

            getRepMap(tfi, sswEngine, reps);
        } else {
            SatSweepingEngine sswEngine(graph, tfi);
            sswEngine.getLogger()->enable();

            sswEngine.run();

            substituteReps(frames[i].inputs, sswEngine);
            substituteReps(frames[i].outputs, sswEngine);
            substituteReps(frames[i].states, sswEngine);
            substituteReps(frames[i].nextStates, sswEngine);

            getRepMap(tfi, sswEngine, reps);
        }

        // This is where we would check the property if we had one.
        cout << getNumUnmergedAnds(graph, reps) << " unmerged AND nodes in graph" << endl;

        if (i < numFrames - 1) {
            // Clear observability of every currently observable node
            // (outputs and next-state functions alike) ...
            for (list<Ref>::iterator it = observable.begin();
                 it != observable.end();
                 ++it) {
                graph.setObservable(*it, false);
            }

            // ... then restore it for the primary outputs of all frames so far.
            for (oaUInt4 j = 0; j <= i; ++j) {
                for (oaUInt4 k = 0; k < frames[j].outputs.size(); ++k) {
                    Ref output = frames[j].outputs[k];
                    graph.setObservable(output, true);
                }
            }

            // Unroll: map each earlier frame onto its successor.
            QueryAi query(&graph, &aiEngine);

            for (oaUInt4 j = 0; j < i; ++j) {
                for (oaUInt4 k = 0; k < frames[j].inputs.size(); ++k) {
                    Ref x = frames[j].inputs[k];
                    Ref y = frames[j + 1].inputs[k];

                    query.set(x, aiEngine.getFunc(y));
                }
                for (oaUInt4 k = 0; k < frames[j].states.size(); ++k) {
                    Ref x = frames[j].states[k];
                    Ref y = frames[j + 1].states[k];

                    query.set(x, aiEngine.getFunc(y));
                }
            }

            // New inputs for frame i + 1.
            for (oaUInt4 j = 0; j < frames[i].inputs.size(); ++j) {
                Ref x = frames[i].inputs[j];
                Ref y = graph.newTerminal(graph.getNull());

                query.set(x, aiEngine.getFunc(y));
                frames[i + 1].inputs.push_back(y);
            }
            // Propagate state from i to i + 1.
            for (oaUInt4 j = 0; j < frames[i].states.size(); ++j) {
                Ref s = frames[i].states[j];
                Ref ns = frames[i].nextStates[j];

                query.set(s, aiEngine.getFunc(ns));
                frames[i + 1].states.push_back(ns);
            }

            // Re-express frame i's outputs and next states as frame i + 1's;
            // they become the sweep targets of the next iteration.
            for (oaUInt4 j = 0; j < frames[i].outputs.size(); ++j) {
                Ref x = frames[i].outputs[j];
                Ref y = aiEngine.getRef(query.get(x));

                frames[i + 1].outputs.push_back(y);
                graph.setObservable(y, true);
            }
            for (oaUInt4 j = 0; j < frames[i].nextStates.size(); ++j) {
                Ref x = frames[i].nextStates[j];
                Ref y = aiEngine.getRef(query.get(x));

                frames[i + 1].nextStates.push_back(y);
                graph.setObservable(y, true);
            }
        }
    }
}

// *****************************************************************************
// main
// *****************************************************************************
/// \brief Entry point: opens the design LIB/CELL/VIEW, prints its
/// functionality, then unrolls and SAT-sweeps it via unrollAndSweep().
///
/// Exits with EXIT_FAILURE on an option-parsing error, a missing library, or
/// any OpenAccess exception.
int main(int argc, const char *argv[]) {
    using namespace oagUtil;

    OptionParser options("Run ODC SAT sweeping on a design.");
    Option *libOpt          = options.add("lib", "Input library name", true, "LIB");
    Option *cellOpt         = options.add("cell", "Input cell name", true, "CELL");
    Option *viewOpt         = options.add("view", "Input view name (default: netlist)", false, "VIEW");
    Option *obsLevelsOpt    = options.add("obsLevels", "Number of levels of observability", false, "NUM");
    Option *unrollOpt       = options.add("unroll", "Number of times to unroll", true, "NUM");
    Option *libDefFileOpt   = options.add("libDefFile", "Use a specific library definitions file", false, "FILE");
    Option *seedOpt         = options.add("randomSeed", "Seed for random number generator", false, "NUM");

    // These options are read by unrollAndSweep(), not here.
    (void) obsLevelsOpt;
    (void) unrollOpt;

    if (!options.parse(argc, argv)) {
        cerr << options.getMessage();
        exit(EXIT_FAILURE);
    }

    try {
        oaDesignInit();
        if (libDefFileOpt->isGiven()) {
            oaLibDefList::openLibs(libDefFileOpt->getValue());
        } else {
            oaLibDefList::openLibs();
        }

        oagFunc::initialize();

        if (seedOpt->isGiven()) {
            srand(atoi(seedOpt->getValue()));
        } else {
            srand(time(0));
        }

        const oaVerilogNS  ns;
        const oaScalarName libName(ns, libOpt->getValue());
        const oaScalarName cellName(ns, cellOpt->getValue());
        const oaScalarName viewName(ns, viewOpt->isGiven() ? viewOpt->getValue()
                                                           : "netlist");

        oaLib *lib = oaLib::find(libName);
        if (!lib) {
            cerr << "Error: Library not found: " << libOpt->getValue() << endl;
            exit(EXIT_FAILURE);
        }

        oaDesign *design = oaDesign::open(libName, cellName, viewName, 'a');

        printFunctionality(design, false, true);
        unrollAndSweep(design, options);

        // Release the design opened above; nothing is saved back to disk.
        design->close();
    } catch(oaException &e) {
        cerr << "OA Exception: " << e.getMsg() << endl;
        exit(EXIT_FAILURE);
    }

    return 0;
}
