/*

Example code for using the CausalStateAnalyzer class. See the
corresponding file for what it does. This program defines and
runs 1D elementary Cellular Automata as an example.

Notes:
- Define NO_SDL to compile this example if you don't have the SDL for graphics. The main algorithm is independent from SDL.
- The graphics mode is the default, press Escape to quit.
- A batch mode is available with -b, a quiet mode with -q.
- ./CA_Complexity -h gives the usage help.
- The images are automatically saved (unless the -d option is given), whether or not SDL is available.

Nicolas Brodu, 04/05/07, file version 1.0
Released under GNU Public Licence v2 or above.

See http://nicolas.brodu.free.fr/en/programmation/mesincom/index.html
for more information and possibly updates.

*/


#include <algorithm>
#include <cassert>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <vector>
using namespace std;
#include "CausalStateAnalyzer.h"

#include <stdlib.h>
#ifndef NO_SDL
#include <SDL/SDL.h>
#endif

// Vector whose subscript wraps around by at most one period in each direction:
// indices in [-size, 2*size) map onto [0, size) without using the % operator.
// Indices further out trip the assertions.
template<typename T>
struct WrappingVector : public std::vector<T> {
    WrappingVector(int size) : std::vector<T>(size) {}

    T& operator[](int idx) {
        return std::vector<T>::operator[](wrap(idx));
    }

    // Read-only access. Without this overload the non-const operator[] above hides
    // vector's own operators, so a const WrappingVector could not be subscripted at all.
    const T& operator[](int idx) const {
        return std::vector<T>::operator[](wrap(idx));
    }

private:
    // Brings idx back into [0, size), assuming it lies at most one period outside.
    int wrap(int idx) const {
        const int size = static_cast<int>(this->size());
        if (idx < 0) idx += size;
        else if (idx >= size) idx -= size;
        assert(idx >= 0); assert(idx < size);
        return idx;
    }
};

// One row of the 1D world: worldSize cells with wrap-around indexing.
// A row takes its length from worldSize when it is constructed, so worldSize
// has to be assigned before any AutomataWorld is created.
struct AutomataWorld : public WrappingVector<int> {
    // Cells per row, shared by every AutomataWorld (zero until assigned).
    static int worldSize;

    AutomataWorld() : WrappingVector<int>(worldSize) {}
};
int AutomataWorld::worldSize = 0;


// For past and future cones. Cell states at t, then at t+-1, t+-2, etc...
struct LightCone : public std::vector<int> {

    // map cones to id so distributions only manipulate the id
    // distributions are possible with the cones directly too, thanks to templatized
    // analyzer, but this is faster
    typedef std::map<LightCone,int> LightConeMap;

    // Strict weak ordering for LightConeMap (O(log(n)) lookups): plain lexicographic
    // order, where a cone that is a strict prefix of another sorts first.
    // Must be const: std::map compares keys through const references, so a
    // non-const operator< would never be selected by the map.
    bool operator<(const LightCone& cone) const {
        return std::lexicographical_compare(begin(), end(), cone.begin(), cone.end());
    }

    // Registry of every distinct cone content seen so far, shared by past and future cones.
    static LightConeMap conesId;

    // Returns the id of this light cone's content. The first time a content is seen
    // it gets the next free id (ids start from 1).
    int getId() const {
        // lower_bound + hinted insert: a single tree search whether or not the cone is new
        LightConeMap::iterator pos = conesId.lower_bound(*this);
        if (pos != conesId.end() && !(*this < pos->first)) return pos->second;
        return conesId.insert(pos, LightConeMap::value_type(*this, ++nextId))->second;
    }

    // Last id handed out by getId (0 before any cone was registered).
    static int nextId;
};
int LightCone::nextId = 0;
LightCone::LightConeMap LightCone::conesId;

// Elementary cellular automaton rule in the usual numbering: bit k of the rule
// number is the next state of a cell whose (left, center, right) neighbourhood,
// read as a 3-bit binary number, equals k.
struct Rule {

    int nextState[8];

    Rule(unsigned int number) {
        // Only the 8 low bits of the rule number are meaningful.
        for (int pattern = 0; pattern < 8; ++pattern) {
            nextState[pattern] = (number >> pattern) & 1u;
        }
    }

    // Next state of a cell from its left neighbour, itself and its right neighbour.
    // All three inputs are expected to be 0 or 1.
    int apply(int left, int center, int right) {
        const int pattern = (left << 2) | (center << 1) | right;
        return nextState[pattern];
    }
};

// Generates the random initial condition: each call returns one cell state, 0 or 1.
struct RandFunctor {
    int operator()() {
        // Bit 17 of rand() was the original choice ("any bit" will do). The C standard
        // only guarantees RAND_MAX >= 32767, and on such platforms bit 17 is always 0,
        // which would make every initial cell dead. Fall back to bit 14 (the highest bit
        // such a generator still provides). Platforms with a larger RAND_MAX keep bit 17,
        // so existing seeds still reproduce the same worlds there.
        const int shift = (RAND_MAX >> 17) != 0 ? 17 : 14;
        return (rand() >> shift) & 1;
    }
};

#ifndef NO_SDL
SDL_Surface *screen;
void drawGrayPixel(int x, int y, int gray) {
    *((int*)screen->pixels + y*screen->pitch/4 + x) = SDL_MapRGB(screen->format, gray, gray
, gray);
}
#else
void drawGrayPixel(int x, int y, int gray) {}
#endif

int main(int argc, char** argv) {

    // User-defined variables
    int past_depth = 4;
    int future_depth = 3;
    int max_steps = 500;
    int skip_steps = 0;
    int rule_number = -1;
    AutomataWorld::worldSize = 500;
    int seed = 1;
    float siglevel = 0.04321;

#ifndef NO_SDL
    bool batchMode = false;
#else
    bool batchMode = true;
#endif
    bool quietMode = false;
    bool zMode = false;
    bool oneByOne = false;
    bool noninc = false;
    bool savegraphics = true;
    bool scaleComplexity = false;
    int c; opterr = 0;
    while ((c=getopt(argc,argv,"zdn1hbcqk:p:f:t:w:s:r:g:"))!=-1) switch(c) {
        case 'd': savegraphics = false; break;
        case 'n': noninc = true; break;
        case 'c': scaleComplexity = true; break;
        case '1': oneByOne = true; break;
        case 'b': batchMode = true; break;
        case 'q': quietMode = true; break;
        case 'z': quietMode = true; zMode=true; break;
        case 'g': {
            float arg = (float)atof(optarg);
            if (arg!=0.0f) siglevel = arg;
            break;
        }
        case 'k': {
            skip_steps = atoi(optarg);
            break;
        }
        case 'p': {
            int arg = atoi(optarg);
            if (arg!=0) past_depth=arg;
            break;
        }
        case 'f': {
            int arg = atoi(optarg);
            if (arg!=0) future_depth=arg;
            break;
        }
        case 't': {
            int arg = atoi(optarg);
            if (arg!=0) max_steps=arg;
            break;
        }
        case 's': {
            int arg = atoi(optarg);
            if (arg!=0) seed=arg;
            break;
        }
        case 'w': {
            int arg = atoi(optarg);
            if (arg!=0) AutomataWorld::worldSize=arg;
            break;
        }
        case 'r': {
            rule_number=atoi(optarg); // 0 is valid
            break;
        }
        case '?':
        case 'h':
        default:
            cout << "Usage:\n-p<integer>: past depth\n-f<integer>: future depth\n-t<integer>: max time\n-w<integer>: world width\n-r<integer>: rule number\n-k<integer>: Skip this number of steps in addition to the minimum of (past_depth+future_depth-1)\n-g<float>: chi-square significance level for clustering distributions\n-b: batch mode\n-q: quiet mode\n-z: really quiet (no output)\n-s<integer>: random seed\n-d: don't save graphics\n-1: process observations one by one (default is row by row)\n-n: Non-incremental algorithm\n-c: Scale complexities using theoretical max so the result can be compared across different rules.\nAll options may be combined except -1 and -n." << endl;
            return 0;
    }

    // derived variables
    const int timeSize = past_depth + future_depth;

    if (rule_number==-1) rule_number = 110; // default

    if (noninc && oneByOne) {
        cout << "Options -1 and -n are incompatible." << endl;
        return 0;
    }

    if (!quietMode) {
        std::ios::sync_with_stdio(false); // GNU specific?
#ifndef NO_SDL
        cout << (batchMode?"batch":"graphics") << " mode" << endl;
#else
        cout << "Batch mode forced (graphics not compiled in, please use SDL)." << endl;
#endif
        cout << "past depth = " << past_depth << endl;
        cout << "future depth = " << future_depth << endl;
        cout << "max time = " << max_steps << endl;
        cout << "world size = " << AutomataWorld::worldSize << endl;
        cout << "rule number = " << rule_number << endl;
        cout << "random seed = " << seed << endl;
        cout << "significance for clustering = " << siglevel << endl;
        cout << "Processing observations " << (oneByOne?"one by one.":(noninc?"non-incrementally":"row by row.")) << endl;
    }

    srand(seed);

#ifndef NO_SDL
    if (!batchMode) {
        // Initialize the SDL library
        if( SDL_Init(SDL_INIT_VIDEO) < 0 ) {
            cerr << "Couldn't initialize SDL: " << SDL_GetError() << endl;
            return 1;
        }

        // Clean up on exit
        atexit(SDL_Quit);

        // Initialize the display in a 32-bit mode,
        screen = SDL_SetVideoMode(AutomataWorld::worldSize*2, max_steps, 32, SDL_HWSURFACE);
        if (!screen) {
            cerr << "Couldn't initialize SDL: " << SDL_GetError() << endl;
            return 1;
        }
    }
#endif

    ofstream filteredImage, rawImage;

    if (savegraphics) {
        // create image files
        stringstream ss; ss << "rule" << rule_number << "_raw_" << AutomataWorld::worldSize << "x" << max_steps << ".pgm";
        rawImage.open(ss.str().c_str(), std::ios::binary);
        if (!quietMode) cout << "Saving raw and filtered images in " << ss.str().c_str() << " and " << flush;

        stringstream ss2; ss2 << "rule" << rule_number << "_complexity_p" << past_depth << "f" << future_depth << "_" << AutomataWorld::worldSize << "x" << max_steps << ".pgm";
        filteredImage.open(ss2.str().c_str(), std::ios::binary);
        if (!quietMode) cout << ss2.str().c_str() << endl;

        stringstream pgmheader; pgmheader << "P5 " << AutomataWorld::worldSize << " " << max_steps << " 255\n";
        rawImage << pgmheader.str();
        filteredImage << pgmheader.str();
    }

    if (!quietMode) {
        if (!savegraphics) cout << "Don't saving graphics" << endl;
        cout << endl;
    }

    // Create the Cellular Automata
    WrappingVector<AutomataWorld> spaceTime(timeSize);
    Rule rule(rule_number);

    // random initial condition
    generate(spaceTime[0].begin(), spaceTime[0].end(), RandFunctor());

    // Use light cones Id for best perf, though using light cones directly would work too
    CausalStateAnalyzer<int,int> analyzer(siglevel);

    int timeSlice = 0;
    int displayCount = 0;

    // fill the first steps where complexity can't be computed
    for (int t = 1; t < timeSize-1; ++t) {
        int lastSlice = timeSlice++;
        // compute new state
        for (int i=0; i<AutomataWorld::worldSize; ++i) {
            spaceTime[timeSlice][i] = rule.apply(
                spaceTime[lastSlice][i-1],
                spaceTime[lastSlice][i],
                spaceTime[lastSlice][i+1]);
        }
    }

    // skip extra steps at the user request
    for (int t = 0; t < skip_steps; ++t) {
        int lastSlice = timeSlice++;
        if (timeSlice >= timeSize) timeSlice = 0;
        // compute new state
        for (int i=0; i<AutomataWorld::worldSize; ++i) {
            spaceTime[timeSlice][i] = rule.apply(
                spaceTime[lastSlice][i-1],
                spaceTime[lastSlice][i],
                spaceTime[lastSlice][i+1]);
        }
    }

    // Now feed the light cones to the analyzer
    if (!quietMode) cout << "Feeding light cones" << endl;

    vector<int> cindices;
    if (noninc) cindices.resize(max_steps * AutomataWorld::worldSize);

    for (int t = 0; t < max_steps; ++t) {
        if (!quietMode) {
            if (displayCount++==0) cout << "CA steps /" << max_steps << ": ";
            cout << (t+1) << " " << flush;
            if (displayCount>=10) {
                displayCount = 0;
                cout << endl;
            }
        }

        if (!noninc) cindices.resize(AutomataWorld::worldSize);

        // build past/future light cones for each points at future_depth in the past
        int present = timeSlice - future_depth + 1;
        // do this for each point
        for (int i=0; i<AutomataWorld::worldSize; ++i) {
            char b = spaceTime[present][i] * 255;
            if (savegraphics) rawImage.write(&b, 1);
            if (!batchMode) drawGrayPixel(i, t, b);

            // Past Light Cone = plc
            LightCone plc;
            for (int time = 0; time < past_depth; ++time) {
                // CA rule has speed light of 1 space/time unit = 1 cell/step
                // So, each step, space delta = time * 1 = time
                for (int space = i-time; space <= i+time; ++space) {
                    plc.push_back(spaceTime[present-time][space]);
                }
            }

            // Future Light Cone = flc
            LightCone flc;
            for (int time = 0; time < future_depth; ++time) {
                for (int space = i-time; space <= i+time; ++space) {
                    flc.push_back(spaceTime[present+time][space]);
                }
            }

            // One more realization, all cells have the same rule
            int cidx = i + (noninc?(AutomataWorld::worldSize*t):0);
            cindices[cidx] = analyzer.addObservation(plc.getId(), flc.getId());

            if (oneByOne) {
                analyzer.commitObservations();

                vector<float>& cmplx = analyzer.getScaledComplexities();
                char grayScale = 255 - (unsigned char)(255.99f * cmplx[cindices[i]]);
                if (!batchMode) drawGrayPixel(AutomataWorld::worldSize+i, t, grayScale);

                if (savegraphics) filteredImage.write(&grayScale, 1);
            }

        }

#ifndef NO_SDL
        if (!batchMode) {
            SDL_Event event;
            SDL_PollEvent(&event);
            switch (event.type) {
                case SDL_KEYDOWN: if (event.key.keysym.sym!=SDLK_ESCAPE) break;
                case SDL_QUIT: return 0;
            }
        }
#endif

        if ((!oneByOne) && (!noninc)) {
            analyzer.commitObservations();
            vector<float>& cmplx = analyzer.getScaledComplexities();

            float scaleFactor = 255.99f;
            if (scaleComplexity) scaleFactor *= CausalStateAnalyzer<int,int>::my_log2f(analyzer.clusters.size()) / CausalStateAnalyzer<int,int>::my_log2f(analyzer.observedDistributions.size());

            for (int i=0; i<AutomataWorld::worldSize; ++i) {

                char grayScale = 255 - (unsigned char)(scaleFactor * cmplx[cindices[i]]);
                if (!batchMode) drawGrayPixel(AutomataWorld::worldSize+i, t, grayScale);

                if (savegraphics) filteredImage.write(&grayScale, 1);
            }
        }

#ifndef NO_SDL
        if (!batchMode) SDL_UpdateRect(screen, 0, t, AutomataWorld::worldSize*2, 1);
#endif


        int lastSlice = timeSlice++;
        if (timeSlice >= timeSize) timeSlice = 0;
        // compute new state
        for (int i=0; i<AutomataWorld::worldSize; ++i) {
            spaceTime[timeSlice][i] = rule.apply(
                spaceTime[lastSlice][i-1],
                spaceTime[lastSlice][i],
                spaceTime[lastSlice][i+1]);
        }

    }

    if (noninc) {
        analyzer.commitObservations();
        vector<float>& cmplx = analyzer.getScaledComplexities();
        float scaleFactor = 255.99f;
        if (scaleComplexity) scaleFactor *= CausalStateAnalyzer<int,int>::my_log2f(analyzer.clusters.size()) / CausalStateAnalyzer<int,int>::my_log2f(analyzer.observedDistributions.size());
        for (int t = 0; t < max_steps; ++t) for (int i=0; i<AutomataWorld::worldSize; ++i) {
            char grayScale = 255 - (unsigned char)(scaleFactor * cmplx[cindices[i+AutomataWorld::worldSize*t]]);
            if (!batchMode) drawGrayPixel(AutomataWorld::worldSize+i, t, grayScale);
            if (savegraphics) filteredImage.write(&grayScale, 1);
        }
#ifndef NO_SDL
        if (!batchMode) SDL_UpdateRect(screen, 0, 0, AutomataWorld::worldSize*2, max_steps);
#endif
    }

    if (savegraphics) {
        rawImage.close();
        filteredImage.close();
    }

    if (!quietMode) {
        cout << "Global complexity: " << analyzer.getGlobalComplexity() << endl;
        cout << "Number of states: " << analyzer.clusters.size() << endl;
        cout << "Number of observed pasts: " << analyzer.observedDistributions.size() << endl;
#ifdef MESINCOM_DEBUG
        for (unsigned int i=0; i<analyzer.clusters.size(); ++i) {
            int nf = analyzer.clusters[i]->distribution.size();
            cout << "nf="<<nf<<"("<<analyzer.clusters[i]->distribution.total<<"): ";
            for (int f=0; f<nf;++f) {
                int flcid = analyzer.clusters[i]->distribution[f].first;
                int cf = analyzer.clusters[i]->distribution[f].second;
                cout <<flcid;
                cout << "("<<cf<<") ";
            }
            cout << endl;
        }
        analyzer.assertNoClusterMatch();
#endif
    }
    else if (!zMode) cout << "Rule=" << rule_number << ", complexity=" <<  analyzer.getGlobalComplexity() << endl;

#ifndef NO_SDL
    if (!batchMode) while (1) {
        SDL_Event event;
        SDL_WaitEvent(&event);
        switch (event.type) {
            case SDL_KEYDOWN: if (event.key.keysym.sym!=SDLK_ESCAPE) break;
            case SDL_QUIT: return 0;
        }
    }
#endif

    return 0;
}
