/*

Code extending the causal state reconstruction algorithm
described in "Automatic Filters for the Detection of Coherent
Structure in Spatiotemporal Systems", by Shalizi et al.

The algorithm is detailed and explained in pseudo-code in Section 5.1
of my PhD dissertation, and a usage example is given in the article "Quantifying
the effect of learning on recurrent spiking neurons". These documents
are available on my web page, see the link below.

This algorithm can estimate the states incrementally as light
cones become available, and can also remove past observations.

The CausalStateAnalyzer class is generic and will process your custom types.
Feed it light cones incrementally as you produce them, and it will
update the causal states. You can then query the best state
estimates found so far.

See the README file for a quick API guide.
See CA_Complexity.cpp for a usage example.

Nicolas Brodu, 04/05/07, file version 1.0
Released under GNU Public Licence v2 or above.

See http://nicolas.brodu.free.fr/en/programmation/mesincom/index.html
for more information and possibly updates.

*/

#include <list>
#include <map>
// unordered_map will soon become standard C++, but it's in the Technical Report 1 namespace at the time of writing
#ifndef NO_TR1
#include <tr1/unordered_map>
#endif
#include <vector>
#include <algorithm>

// WARNING: Serialization is back to experimental status: Untested and very probably broken since 0.7.
// Including boost is optional, this creates a linking dependency to the
// serialization lib, whereas this file is otherwise standalone
// Note 1: it may be necessary to define NO_TR1 too
#ifdef MESINCOM_SERIALIZE
#include <boost/serialization/base_object.hpp>
#include <boost/serialization/list.hpp>
#include <boost/serialization/map.hpp>
#include <boost/serialization/vector.hpp>
#endif

#include <assert.h>
#include <math.h>
#include <stdlib.h>

namespace CausalStateAnalyzerDetail {
    // Default random source of CausalStateAnalyzer: operator()(N) returns a value in [0, N)
    // taken from the C library rand(), so it shares rand()'s global state and srand() seeding.
    struct RandomFunctor {
        int operator()(int N) {
            assert(N > 0); // rand() % 0 would be undefined
            return rand() % N;
        }
#ifdef MESINCOM_SERIALIZE
        // rand()'s internal state cannot be read back. On save, we draw a fresh seed from it,
        // reseed the generator with that seed and store it. Loading reseeds with the same
        // value, so the saved and the restored program continue with the same sequence.
        template<class Archive> void save(Archive & ar, const unsigned int version) const {
            unsigned int seed = static_cast<unsigned int>(rand());
            srand(seed);
            ar & seed;
        }
        template<class Archive> void load(Archive & ar, const unsigned int version) {
            unsigned int seed;
            ar & seed;
            srand(seed);
        }
        BOOST_SERIALIZATION_SPLIT_MEMBER()
#endif
    };
}

template<class PastLightCone, class FutureLightCone, typename RandomFunctor = CausalStateAnalyzerDetail::RandomFunctor>
struct CausalStateAnalyzer {

// IMPLEMENTATION, see user API below

    // Threshold that FutureDistribution::match compares its chi-square tail probability against;
    // the constructor sets it to 1 - significanceLevel.
    float matchLevel;

    // Random source used by reCluster() to process the pending observations in random order.
    RandomFunctor randfun;

    struct Cluster;
    // Reference-counted indirection cell shared by all copies of a ClusterPtr.
    // Its `cluster` field is redirected when clusters are merged, and the cell
    // deletes itself when its last reference is released.
    struct CountedPtr {
        Cluster* cluster;
        int count;
        CountedPtr(Cluster* c); // forward ref, see Cluster below
        ~CountedPtr();          // forward ref, see Cluster below
        void ref() { count += 1; }
        void unref() {
            count -= 1;
            if (count == 0) delete this;
        }
    };

    // A raw Cluster* doesn't work since it is invalidated when the cluster is merged ("fusioned") into another one
    // => Shared ptr semantics on a Cluster*, so we can change the cluster for all shared copies
    protected:
    struct ClusterPtr {
        // Shared reference-counted cell; 0 for a default-constructed (null) ClusterPtr.
        CountedPtr* indirect;
        inline ClusterPtr() : indirect(0) {}
        inline ClusterPtr(Cluster* cluster) : indirect(new CountedPtr(cluster)) {}
        inline ClusterPtr(const ClusterPtr& id) : indirect(id.indirect) {if (indirect) indirect->ref();}
        // Reference the incoming cell before releasing ours. Releasing first would delete
        // the cell on a self-assignment of its last reference, and then touch freed memory.
        inline ClusterPtr& operator=(const ClusterPtr& id) {
            CountedPtr* incoming = id.indirect;
            if (incoming) incoming->ref();
            if (indirect) indirect->unref();
            indirect = incoming;
            return *this;
        }
        inline ~ClusterPtr() {if (indirect) {indirect->unref(); indirect=0;}}
        inline void set(Cluster* cluster); // special... see below
        // Detach from the current shared cell (if any) and point to a brand new one for `cluster`.
        // The new cell is allocated first, so a failing allocation leaves this ClusterPtr untouched.
        inline void setnew(Cluster* cluster) {
            assert(cluster);
            CountedPtr* fresh = new CountedPtr(cluster);
            if (indirect) indirect->unref();
            indirect = fresh;
        }
        // The cluster currently pointed to, or 0 when null or when the cluster was destroyed.
        inline Cluster* get() const {if (indirect) return indirect->cluster; return 0;}
        inline Cluster* operator->() const {assert(indirect && indirect->cluster); return indirect->cluster;}
        inline bool empty() const {return get()==0;}
    };
    public:


    // Pending (not yet reclustered) observation changes for one past light cone.
    // - insertion/removal matters more than traversal
    // - keeping futures sorted makes the later merge into a FutureDistribution O(N)
    // => std::map from future light cone to a signed count change
    struct DeltaObsHolder : public std::map<FutureLightCone, int> {
        typedef typename std::map<FutureLightCone, int>::iterator iterator;
        typedef typename std::map<FutureLightCone, int>::const_iterator const_iterator;

        // Net change in the number of observations for this past.
        int total;
        // Slot of this past in userSequence; -1 by default.
        int seqNum;

        DeltaObsHolder() : total(0), seqNum(-1) {}

#ifdef MESINCOM_SERIALIZE
        template<class Archive> void serialize(Archive & ar, const unsigned int version) {
            ar & boost::serialization::base_object< std::map<FutureLightCone, int> >(*this);
            ar & total & seqNum;
        }
#endif
    };

    /*  Holds a distribution over future cones
        It needs:
        - Sorted property
        - traversal (esp. for the match method).
        - insert/remove
        A map would provide O(n) traversal but the O is too big...
        Use a vector, at the cost of O(n) insert/remove, because we really need locality of reference for traversal.
        Entries are kept sorted by future light cone (operator<) and, after add/update, only positive counts are stored.
    */
    struct FutureDistribution : public std::vector< std::pair<FutureLightCone, int> > {
        // Sum of all the counts
        int total;
        // the cluster this distribution belongs to. May change behind our back when clusters fusion
        ClusterPtr cluster;

        typedef typename std::vector< std::pair<FutureLightCone, int> >::iterator iterator;
        typedef typename std::vector< std::pair<FutureLightCone, int> >::const_iterator const_iterator;

        inline FutureDistribution() : total(0) {}
        inline FutureDistribution(const DeltaObsHolder& obs) : total(0) {
            update(obs);
        }

        // Add another distribution to this one: sorted merge, counts of equal futures are summed.
        // Either distribution may be empty.
        inline void add(const FutureDistribution& distribution) {
            const_iterator it2 = distribution.begin();
            const_iterator eit2 = distribution.end();
            iterator it1 = this->begin();
            iterator eit1 = this->end();
            std::vector< std::pair<FutureLightCone, int> > result;
            while (it1 != eit1 && it2 != eit2) {
                if (it1->first < it2->first) {
                    result.push_back(*it1++);
                }
                else if (it2->first < it1->first) {
                    result.push_back(*it2++);
                }
                else {
                    it1->second += it2->second;
                    result.push_back(*it1);
                    ++it1; ++it2;
                }
            }
            // more elements remain in list 1 ?
            while (it1!=eit1) result.push_back(*it1++);
            // more elements remain in list 2 ?
            while (it2!=eit2) result.push_back(*it2++);
#ifdef MESINCOM_MEM_TIGHT
            *(std::vector< std::pair<FutureLightCone, int> >*)this = result;
#else
            this->swap(result);
#endif
            total += distribution.total;
        }

        // Subtract another distribution from this one. Every future of `distribution` must
        // be present here with at least the same count (asserted); entries reaching 0 are dropped.
        inline void sub(const FutureDistribution& distribution) {
            const_iterator it2 = distribution.begin();
            const_iterator eit2 = distribution.end();
            iterator it1 = this->begin();
            iterator eit1 = this->end();
            std::vector< std::pair<FutureLightCone, int> > result;
            while (it1 != eit1 && it2 != eit2) {
                if (it1->first < it2->first) result.push_back(*it1++);
                else {
                    // user can't remove a dist that was not added, elements are equal
                    assert(it1->first == it2->first);
                    it1->second -= it2->second;
                    if (it1->second>0) result.push_back(*it1);
                    ++it1; ++it2;
                }
            }
            assert(it2==eit2);
            // more elements remain in list 1 ?
            while (it1!=eit1) result.push_back(*it1++);
#ifdef MESINCOM_MEM_TIGHT
            *(std::vector< std::pair<FutureLightCone, int> >*)this = result;
#else
            this->swap(result);
#endif
            total -= distribution.total;
            assert(total>=0); // user can't remove a distribution that wasn't added
            // total maybe 0, but then, the cluster will be destroyed
        }


        // Update this distribution with new observations (added or subtracted).
        // Entries whose count drops to 0 are removed. Returns the new total.
        inline int update(const DeltaObsHolder& obs) {
            typename DeltaObsHolder::const_iterator it2 = obs.begin();
            typename DeltaObsHolder::const_iterator eit2 = obs.end();
            iterator it1 = this->begin();
            iterator eit1 = this->end();
            std::vector< std::pair<FutureLightCone, int> > result;
            while (it1 != eit1 && it2 != eit2) {
                if (it1->first < it2->first) {
                    result.push_back(*it1++);
                }
                else if (it2->first < it1->first) {
                    assert(it2->second>=0); // should have found an existing entry in list1 for neg instances
                    // a future may have been added and removed before commit, test for nullity
                    if (it2->second>0) result.push_back(*it2);
                    ++it2;
                }
                else {
                    it1->second += it2->second;
                    assert(it1->second>=0);
                    if (it1->second>0) result.push_back(*it1);
                    ++it1; ++it2;
                }
            }
            // more elements remain in list 1 ?
            while (it1!=eit1) result.push_back(*it1++);
            // more elements remain in list 2 ?
            while (it2!=eit2) {
                assert(it2->second>=0); // should have found an existing entry in list1 for neg instances
                // a future may have been added and removed before commit, test for nullity
                if (it2->second>0) result.push_back(*it2);
                ++it2;
            }
#ifdef MESINCOM_MEM_TIGHT
            *(std::vector< std::pair<FutureLightCone, int> >*)this = result;
#else
            this->swap(result);
#endif
            total += obs.total;
            assert(total>=0);
            return total;
        }

        // Two-sample chi-square test between this distribution and `dist`, using one bin per future
        // light cone present in either of them (ndegree = number of bins).
        // Returns true when Q(ndegree/2, chisq/2), the upper regularized incomplete gamma function
        // (the chi-square tail probability), is above matchLevel. Both distributions must be non-empty.
        // NOTE(review): the analyzer constructor sets matchLevel to 1 - significanceLevel, so this
        // requires a tail probability above 1 - significance, whereas a conventional test compares it
        // against the significance itself. Confirm that this is intended.
        bool match(const FutureDistribution& dist, float matchLevel) {
            assert(!this->empty() && !dist.empty());
            float n1 = total;
            float n2 = dist.total;
            float n1n2 = n1 * n2;
            float n1overn2 = n1 / n2;
            float n2overn1 = n2 / n1;
            float chisq = 0.0f;
            unsigned int ndegree = 0;

            // create a bin for each distinct entry in either of the distributions
            // use the fact that both lists are sorted (all stored counts are > 0)
            const_iterator it1 = this->begin();
            const_iterator eit1 = this->end();
            const_iterator it2 = dist.begin();
            const_iterator eit2 = dist.end();
            while (it1 != eit1 && it2 != eit2) {
                ++ndegree;
                if (it1->first < it2->first) {
                    // bin only present here: (sqrt(n2/n1).R)^2 / R
                    chisq += n2overn1 * (it1++)->second;
                }
                else if (it2->first < it1->first) {
                    // bin only present in dist: (sqrt(n1/n2).S)^2 / S
                    chisq += n1overn2 * (it2++)->second;
                }
                else {
                    // (n2.R - n1.S)^2 / (n1.n2.(R+S))
                    float tmp = n2 * it1->second - n1 * it2->second;
                    chisq += tmp * tmp / (n1n2 * ((it1++)->second + (it2++)->second));
                }
            }
            // more elements remain in list 1 ?
            while (it1!=eit1) {
                chisq += n2overn1 * (it1++)->second;
                ++ndegree;
            }
            // more elements remain in list 2 ?
            while (it2!=eit2) {
                chisq += n1overn2 * (it2++)->second;
                ++ndegree;
            }

            // Q(a, 0) = 1 for any a
            if (chisq<1e-6f) return true;

            float x = 0.5f*chisq;
            float negxdivln2 = x / (-0.69314718056f); // -x/ln(2), so that 2^negxdivln2 = e^-x
            float l2gamma;          // log2(gamma(a)) for the current order a
            float nig;              // normalized incomplete gamma Q(a, x) for the current order a
            unsigned int twice_a;   // 2a, same parity as ndegree

            // Custom from-scratch normalized incomplete gamma Q(0.5f*ndegree, 0.5f*chisq) implementation
            // that is faster than usual power-series approx for the special cases of this program.
            // Start from:
            // - Q(1/2, x) = erfc(sqrt(x)), gamma(1/2) = sqrt(pi)   (odd ndegree)
            // - Q(1, x) = e^-x, gamma(1) = 1                      (even ndegree)
            // then raise a by 1 until 2a = ndegree, using:
            // - Q(a+1, x) = Q(a, x) + x^a.e^-x / gamma(a+1)
            // - gamma(a+1) = a.gamma(a)
            if (ndegree & 1) {
                twice_a = 1;
                nig = erfcf(sqrtf(x));
                l2gamma = 0.825748064739f; // = log2(sqrt(pi))
            } else {
                twice_a = 2;
                nig = CausalStateAnalyzer::my_exp2f(negxdivln2);
                l2gamma = 0.0f;
            }
            if (twice_a == ndegree) return nig > matchLevel;

            float halflog2x = 0.5f * CausalStateAnalyzer::my_log2f(x); // constant through the loop

            for (; twice_a < ndegree; twice_a += 2) {
                // a = twice_a/2 is the current order, compute Q(a+1, x)
                l2gamma += cachedLogHalf(twice_a); // log2(gamma(a+1)) = log2(gamma(a)) + log2(a)
                // x^a.e^-x / gamma(a+1), computed in base 2
                float sumTerm = CausalStateAnalyzer::my_exp2f(twice_a * halflog2x + negxdivln2 - l2gamma);
                nig += sumTerm;
                // early break, since we're always adding >0 numbers
                if (nig > matchLevel) return true;
                // The next terms shrink by ratios <= x/(a+1). Once that ratio is < 1, the rest of the sum
                // is at most sumTerm.x/(a+1-x). While a+1 <= x the terms are still growing, so a small
                // term alone does not prove convergence.
                float aNext = 0.5f * twice_a + 1.0f;
                if (aNext > x && sumTerm * x < 1e-6f * (aNext - x)) return false;
            }
            return false; // nig > matchLevel would have been detected by the loop
        }

        // log2(n/2), i.e. log2(a) for n = 2a. Memoized in a shared, unsynchronized static cache.
        inline float cachedLogHalf(unsigned int n) {
            static std::vector<float> cachedLog;
            static std::vector<bool> hasLog;
            if (hasLog.size() > n && hasLog[n]) return cachedLog[n];
            if (hasLog.size() < n+1) {hasLog.resize(n+1); cachedLog.resize(n+1);}
            hasLog[n] = true; cachedLog[n] = CausalStateAnalyzer::my_log2f(n) - 1.0f;
            return cachedLog[n];
        }

#ifdef MESINCOM_SERIALIZE
        template<class Archive> void serialize(Archive & ar, const unsigned int version) {
            ar & boost::serialization::base_object< std::vector< std::pair<FutureLightCone, int> > >(*this);
            ar & total;
            ar & cluster; // deep serialization
        }

#endif
    };

    // 2^x as a float. With a GNU-compatible compiler on x86 / x86-64, this uses a short x87
    // sequence (from GNU libm, stripped down to the cases relevant here). On any other
    // target, or when NO_ASSEMBLY is defined, it falls back to the C library exp2f.
    inline static float my_exp2f(float x) {
#if !defined(NO_ASSEMBLY) && defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
        float ret;
        // from gnu libm, stripped down to cover only relevant cases here
        asm(
        "fld     %%st(0)\n"
        "frndint\n"                    // int(x), current rounding mode doesn't matter
        "fsubr   %%st(0),%%st(1)\n"    // fract(x), between -1.0 and 1.0, OK for f2xm1
        "fxch\n"
        "f2xm1\n"                      // 2^fract(x) - 1
        "fld1\n"
        "faddp\n"                      // 2^fract(x)
        "fscale\n"                     // 2^x
        "fstp    %%st(1)\n"
        : "=t" (ret)
        : "0" (x)
        : "st(1)"
        );
        return ret;
#else
        return exp2f(x);
#endif
    }

    // log2(x) as a float. With a GNU-compatible compiler on x86 / x86-64, this uses the x87
    // fyl2x instruction (1 * log2(x)). On any other target, or when NO_ASSEMBLY is defined,
    // it falls back to the C library log2f.
    inline static float my_log2f(float x) {
#if !defined(NO_ASSEMBLY) && defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
        float ret;
        // custom implementation
        asm(
        "fld1\n"
        "fxch\n"
        "fyl2x"
        : "=t" (ret)
        : "0" (x)
        : "st(1)"
        );
        return ret;
#else
        return log2f(x);
#endif
    }


    // A cluster contains an average future distribution,
    // mapping pasts to cluster is external
    struct Cluster {

        // The average distribution (accumulated counts of all member distributions)
        FutureDistribution distribution;
        // The number of times this cluster was observed
        int ninstances;

        // cluster complexity is -log(ninstances/totalInstances);
        // keep log(ninstances) since this is only modified when changing dist
        // then log(totalInstances) is computed only once for all clusters, and
        // a simple subtraction gives the complexity (instead of a log)
        protected:
        float logNinstances;
        bool needRecomputeLogNinstances;
        public:

        // on-demand, only if necessary, log2
        inline float getLogNinstances() {
            if (needRecomputeLogNinstances) {
                logNinstances = CausalStateAnalyzer::my_log2f(ninstances);
                needRecomputeLogNinstances = false;
            }
            return logNinstances;
        }

        // Both constructors mark the cached log as stale, so that getLogNinstances()
        // never reads an uninitialized flag.
        inline Cluster() : ninstances(0), logNinstances(0.0f), needRecomputeLogNinstances(true) {
            distribution.cluster.setnew(this);
        }

        inline Cluster(FutureDistribution* dist) : ninstances(0), logNinstances(0.0f), needRecomputeLogNinstances(true) {
            distribution.cluster.setnew(this);
            merge(dist);
        }


        // Add a distribution that currently belongs to no cluster, and make it point to this one.
        inline void merge(FutureDistribution* dist) {
            assert(dist->cluster.empty());
            distribution.add(*dist);
            ninstances += dist->total;
            needRecomputeLogNinstances = true;
            dist->cluster = distribution.cluster;
        }

        // Take a member distribution out of this cluster, and detach it.
        inline void remove(FutureDistribution* dist) {
            assert(dist->cluster.get()==this);
            distribution.sub(*dist);
            ninstances -= dist->total;
            assert(ninstances>=0); // and log = -inf doesn't matter for destroyed cluster
            needRecomputeLogNinstances = true;
            dist->cluster.set(0);
        }

        // Note: the other cluster will be deleted after the merge
        inline void merge(Cluster* cluster) {
            assert(cluster != this); // splicing a list into itself is undefined
            distribution.add(cluster->distribution);
            ninstances += cluster->ninstances;
            needRecomputeLogNinstances = true;
            // update all shared ptrs toward the other cluster to point to this one
            for (typename std::list<CountedPtr*>::iterator it = cluster->backrefs.begin(); it != cluster->backrefs.end(); ++it) (*it)->cluster = this;
            // and take the other cluster shared ptrs
            backrefs.splice(backrefs.end(), cluster->backrefs);
        }

#ifdef MESINCOM_SERIALIZE
        template<class Archive> void serialize(Archive & ar, const unsigned int version) {
            ar & distribution;
            ar & ninstances & logNinstances & needRecomputeLogNinstances;
        }
#endif

        // update all shared ptrs when this cluster fusions with another
        std::list<CountedPtr*> backrefs;

        // break self-ref loop on destruction
        ~Cluster() {
            distribution.cluster.set(0);
            // All pointers to this cluster are invalidated
            for (typename std::list<CountedPtr*>::iterator it = backrefs.begin(); it != backrefs.end(); ++it) (*it)->cluster = 0;
        }
    };

    // All existing clusters, irrespectively of shared ptrs: clusters.size() is the real number of clusters.
    // The analyzer deletes them itself: in the destructor, or in reCluster() when one becomes empty or is merged.
    std::vector<Cluster*> clusters;

    // Maps keyed by past light cone: tr1 hash maps, or std::map when NO_TR1 is defined.
#ifndef NO_TR1
    typedef std::tr1::unordered_map<PastLightCone,FutureDistribution*> DistributionMap;
    typedef std::tr1::unordered_map<PastLightCone, DeltaObsHolder> DeltaObsMap;
#else
    typedef std::map<PastLightCone,FutureDistribution*> DistributionMap;
    typedef std::map<PastLightCone, DeltaObsHolder> DeltaObsMap;
#endif

    // Current future distribution of every past still present; updated by reCluster(), deleted in the destructor.
    DistributionMap observedDistributions;
    // Pending per-past observation changes, applied to the distributions by reCluster().
    DeltaObsMap deltaObs;

    // Total number of observations over all clusters, updated at the end of reCluster().
    int totalInstances;
    // subtract each cluster logNinstances from log2(totalInstances) to get that cluster complexity

protected: // cached values, possibly stale or not computed yet: read them through the accessors below
    float minLogNinstances;                 // smallest cluster log2(ninstances)
    float maxLogNinstances;                 // largest cluster log2(ninstances)
    Cluster* cMinLogNinstances;             // cluster holding the minimum
    Cluster* cMaxLogNinstances;             // cluster holding the maximum
    bool needToRecomputeMinMax;             // the four values above are stale
    float logTotalInstances;                // cached log2(totalInstances)
    bool needToRecomputeLogTotalInstances;  // logTotalInstances is stale
public:

    // Smallest log2(ninstances) over all clusters. The O(N) rescan only runs when the cached
    // value is flagged stale; otherwise this is O(1).
    inline float getMinLogNinstances() {
        if (!needToRecomputeMinMax) return minLogNinstances;
        recomputeMinMaxLogNinstances();
        return minLogNinstances;
    }
    // Largest log2(ninstances) over all clusters, cached in the same way.
    inline float getMaxLogNinstances() {
        if (!needToRecomputeMinMax) return maxLogNinstances;
        recomputeMinMaxLogNinstances();
        return maxLogNinstances;
    }

    inline void recomputeMinMaxLogNinstances() {
        minLogNinstances = 1e40f;
        maxLogNinstances = -1e40f;
        for (unsigned int i=0; i<clusters.size(); ++i) {
            if (clusters[i]->getLogNinstances()<minLogNinstances) {
                minLogNinstances = clusters[i]->getLogNinstances();
                cMinLogNinstances = clusters[i];
            }
            if (clusters[i]->getLogNinstances()>maxLogNinstances) {
                maxLogNinstances = clusters[i]->getLogNinstances();
                cMaxLogNinstances = clusters[i];
            }
        }
        needToRecomputeMinMax = false;
    }

    // log2(totalInstances), recomputed lazily after it has been flagged stale.
    inline float getLogTotalInstances() {
        if (!needToRecomputeLogTotalInstances) return logTotalInstances;
        needToRecomputeLogTotalInstances = false;
        logTotalInstances = my_log2f(totalInstances);
        return logTotalInstances;
    }

    // NOTE(review): in this part of the file it is only set to 0 by the constructor and serialized;
    // presumably the counter that hands out DeltaObsHolder::seqNum values in the user API. Confirm.
    int deltaObsSeq;
    // Cluster assigned to each pending past by reCluster(), indexed by DeltaObsHolder::seqNum
    // (set to null when that past has no distribution any more).
    // Each cluster is back-refed and changed if it later fusions with another cluster due to a subsequent observation!
    typedef std::vector<ClusterPtr> ObservedClusters;
    ObservedClusters userSequence;
    // NOTE(review): not filled in this part of the file; presumably per-observation complexities
    // reported to the user (raw and scaled). Confirm against the user API.
    std::vector<float> userComplexities;
    std::vector<float> userScaledComplexities;

    void reCluster() {

        userSequence.resize(deltaObs.size()); // All are null refs for now

        int deltaInstances = 0;

        // A proxy vector to access the map elements randomly
        // store pair pointers and NOT map iterators
        std::vector<const typename DeltaObsMap::value_type*> randomAccessor;
        randomAccessor.reserve(deltaObs.size());

        typename DeltaObsMap::const_iterator doEnd = deltaObs.end();
        for (typename DeltaObsMap::const_iterator doElt = deltaObs.begin(); doElt != doEnd; ++doElt) {
            randomAccessor.push_back(&(*doElt));
        }

        // process the observations in random order
        for (int dobsN = randomAccessor.size(); dobsN>0; --dobsN) {
            int dobsI = randfun(dobsN);
            const typename DeltaObsMap::value_type* doit = randomAccessor[dobsI];
            swap(randomAccessor[dobsI], randomAccessor[dobsN-1]);

            // total may be 0, but with one future added and another removed
            // so the distribution changes.
            // see after this main loop too
            deltaInstances += doit->second.total;

            const PastLightCone& plc = doit->first;

            // find distribution for that past, if it exists
            typename DistributionMap::iterator distIt = observedDistributions.find(plc);

            FutureDistribution* observedDistribution;

            Cluster* oldCluster = 0;

            // create a new distribution if this plc was not found
            if (distIt == observedDistributions.end()) {
                // Unlike the above, a past that wasn't present may not have negative changes
                // so all flc count must be >=0
                assert(doit->second.total>=0);
                // this may happen when adding and removing the same observation before commit
                if (doit->second.total == 0) {
                    // but then no cluster match for the user sequence
                    userSequence[doit->second.seqNum].set(0);
                    continue;
                }
                observedDistribution = new FutureDistribution(doit->second);
                observedDistributions[plc] = observedDistribution;
            }
            // found: remove the distribution from its cluster and update both
            else {

                observedDistribution = distIt->second;
                Cluster* cluster = observedDistribution->cluster.get();

                // update cluster distrib. This may change the min/max
                cluster->remove(observedDistribution);
                // it doesn't matter whether other clusters were also min/max
                // but when the unique min/max is changed, we have to check
                bool wasMinCluster = cluster == cMinLogNinstances;
                bool wasMaxCluster = cluster == cMaxLogNinstances;

                // destroy the cluster if it has become empty
                if (cluster->ninstances==0) {
                    typename std::vector<Cluster*>::iterator it = find(clusters.begin(), clusters.end(), cluster);
                    *it = clusters.back();
                    clusters.pop_back();
                    // Potential hazard with user sequence cannot happen. If a cluster is
                    // valid for one past, it cannot be destroyed for another past later
                    // on since it won't then be empty. Several observations may be made on
                    // the same past, but pasts are processed only once.
                    // Worse case is cluster fusion but that is handled too.
                    delete cluster; cluster = 0;
#ifdef MESINCOM_MEM_TIGHT
                    std::vector<Cluster*>(clusters).swap(clusters);
#endif
                }

                // if needToRecomputeMinMax is set already, don't bother maintaining here
                if (!needToRecomputeMinMax) {
                    if (cluster != 0 && cluster->getLogNinstances() < minLogNinstances) {
                        minLogNinstances = cluster->getLogNinstances();
                        cMinLogNinstances = cluster;
                        wasMinCluster = false;
                    }
                    if (cluster != 0 && cluster->getLogNinstances() > maxLogNinstances) {
                        maxLogNinstances = cluster->getLogNinstances();
                        cMaxLogNinstances = cluster;
                        wasMaxCluster = false;
                    }

                    // need to recompute min/max?
                    if (wasMinCluster || wasMaxCluster) needToRecomputeMinMax = true;
                }

                // now update the observedDistribution & check whether it becomes empty (removed obs)
                if (observedDistribution->update(doit->second) == 0) {
                    // that past has no future anymore... cleanup and done!
                    observedDistributions.erase(distIt);
                    delete observedDistribution;
                    userSequence[doit->second.seqNum].set(0);
                    continue;
                }

                oldCluster = cluster;
            }

            // The old cluster has changed. Check it doesn't match other clusters, and if so, fusion them
            if (oldCluster) {
                bool fusionAgain = false;
                do {
                    vector<int> matchingClusters;
                    for (unsigned int i = 0; i< clusters.size(); ++i) {
                        if (clusters[i]==oldCluster || oldCluster->distribution.match(clusters[i]->distribution, matchLevel)) {
                            matchingClusters.push_back(i);
                        }
                    }
                    // now get all matching clusters in order
                    oldCluster = clusters[matchingClusters[0]];
                    fusionAgain = matchingClusters.size() > 1;

                    // Fusion all new clusters (looping from index end so swap does not invalidate next elements in loop)
                    for (int i=matchingClusters.size()-1; i>0; --i) {
                        Cluster *mergedOne = clusters[matchingClusters[i]];
                        // fusion other cluster to the current cluster
                        oldCluster->merge(mergedOne);
                        clusters[matchingClusters[i]] = clusters.back();
                        clusters.pop_back();
                        delete mergedOne;
                        needToRecomputeMinMax = true;
                    }
                } while (fusionAgain);
            }

            // not randomized: this is an exhaustive matching anyway, and we need sorted index property for the O(N) fusion
            vector<int> matchingClusters;
            for (unsigned int i = 0; i< clusters.size(); ++i) {
                if (observedDistribution->match(clusters[i]->distribution, matchLevel)) {
                    matchingClusters.push_back(i);
                }
            }

            Cluster* cluster;

            // Insert the distribution in the cluster just found, or create a new
            // one. This may be or not the old cluster.

            // no cluster found, create a new one
            if (matchingClusters.empty()) {
                cluster = new Cluster(observedDistribution);
                clusters.push_back(cluster);
            }
            // or update it
            else {
                cluster = clusters[matchingClusters[0]];
                cluster->merge(observedDistribution);

                while (matchingClusters.size()>1) {
                    // Fusion all new clusters (looping from index end so swap does not invalidate next elements in loop)
                    for (int i=matchingClusters.size()-1; i>0; --i) {
                        Cluster *mergedOne = clusters[matchingClusters[i]];
                        // fusion other cluster to the current cluster
                        cluster->merge(mergedOne);
                        clusters[matchingClusters[i]] = clusters.back();
                        clusters.pop_back();
                        delete mergedOne;
                        needToRecomputeMinMax = true;
                    }

                    matchingClusters.clear();
                    for (unsigned int i = 0; i< clusters.size(); ++i) {
                        if (clusters[i]==cluster || cluster->distribution.match(clusters[i]->distribution, matchLevel)) {
                            matchingClusters.push_back(i);
                        }
                    }
                    cluster = clusters[matchingClusters[0]];
                }
            }

            userSequence[doit->second.seqNum].set(cluster);

            // if needToRecomputeMinMax is set already, don't bother maintaining here
            if (!needToRecomputeMinMax) {
                if (cluster->getLogNinstances() < minLogNinstances) {
                    minLogNinstances = cluster->getLogNinstances();
                    cMinLogNinstances = cluster;
                }
                if (cluster->getLogNinstances() > maxLogNinstances) {
                    maxLogNinstances = cluster->getLogNinstances();
                    cMaxLogNinstances = cluster;
                }
            }
        }

        // update total & log
        totalInstances += deltaInstances;
        needToRecomputeLogTotalInstances = true;
    }

    // Debug routine: asserts that no two distinct clusters match each other at matchLevel.
    // It is compiled to an empty function under NDEBUG.
#if defined(NDEBUG)
    inline void assertNoClusterMatch() {}
#else
    void assertNoClusterMatch() {
        const unsigned int count = clusters.size();
        for (unsigned int a = 0; a < count; ++a) {
            for (unsigned int b = 0; b < count; ++b) {
                if (a != b) assert(!clusters[b]->distribution.match(clusters[a]->distribution, matchLevel));
            }
        }
    }
#endif

    // Destructor: deletes every Cluster held in `clusters` (last to first),
    // then every object stored as a mapped value in `observedDistributions`.
    // The order (clusters first, then distributions) is the same as in the
    // original implementation and is kept on purpose.
    inline ~CausalStateAnalyzer() {
        // delete the back element, then drop it, until nothing is left
        for (; !clusters.empty(); clusters.pop_back()) {
            delete clusters.back();
        }
        for (typename DistributionMap::iterator entry = observedDistributions.begin();
             entry != observedDistributions.end(); ++entry) {
            delete entry->second;
        }
        observedDistributions.clear();
    }


#ifdef MESINCOM_SERIALIZE
    // Boost.Serialization hook: the same `ar & x` sequence is used both to save
    // and to load, so this function defines the archive layout. Do not reorder
    // these statements without introducing versioning.
    // NOTE(review): `version` is currently ignored. The file header flags
    // serialization as experimental/untested. `clusters`, `observedDistributions`
    // values and cMin/cMaxLogNinstances presumably hold raw pointers, which
    // relies on Boost's pointer tracking and on Cluster being serializable —
    // TODO confirm.
    template<class Archive> void serialize(Archive & ar, const unsigned int version) {
        ar & matchLevel;
        ar & clusters & observedDistributions & deltaObs;
        ar & totalInstances & logTotalInstances;
        ar & minLogNinstances & maxLogNinstances;
        ar & cMinLogNinstances & cMaxLogNinstances;
        ar & needToRecomputeMinMax & needToRecomputeLogTotalInstances;
        ar & deltaObsSeq & userSequence;
        ar & userComplexities & userScaledComplexities;
        ar & randfun;
    }
#endif


///////////////////////////////////////////////////////////////////////////
// USER API
///////////////////////////////////////////////////////////////////////////

    // Create an analyzer.
    // Future light cone distributions are clustered together when they match
    // at the given significance level; internally this is stored as
    // matchLevel = 1 - significanceLevel.
    // The default 0.04321 is deliberately close to, but not exactly, 0.05.
    // Cached quantities (min/max log-counts, log of the total) are flagged as
    // stale here rather than initialized.
    inline CausalStateAnalyzer(float significanceLevel = 0.04321)
        : matchLevel(1.0f - significanceLevel)
        , totalInstances(0)
        , needToRecomputeMinMax(true)
        , needToRecomputeLogTotalInstances(true)
        , deltaObsSeq(0)
    {
    }

    // Add a new observed future light cone for a given past light cone
    // The observation is just stored for now, reclustering is a separate
    // routine: Call commitObservations when you wish so. This allows you
    // to make the algorithm fully incremental (commit after each change)
    // to fully post-processing (1 big commit at the end).
    // A reference index is returned, that will match the returned vector
    // when clustering (see that function).
    inline int addObservation(const PastLightCone& plc, const FutureLightCone& flc) {
        DeltaObsHolder& holder = deltaObs[plc];              // may insert a new one
        ++holder[flc];                                       // delta counts for that future
        ++holder.total;                                      // delta total
        if (holder.seqNum==-1) holder.seqNum = deltaObsSeq++;// need new index ?
        return holder.seqNum;                                // sequence index for the user
    }

    // Useful for systems changing over time or space. When "old" observations
    // have a different enough distribution from current observation (non-
    // stationary systems) and should be discarded.
    // See also the comments for the add function.
    inline int removeObservation(const PastLightCone& plc, const FutureLightCone& flc) {
        DeltaObsHolder& holder = deltaObs[plc];              // may insert a new one
        --holder[flc];                                       // delta counts for that future
        --holder.total;                                      // delta total
        if (holder.seqNum==-1) holder.seqNum = deltaObsSeq++;// need new index ?
        return holder.seqNum;                                // sequence index for the user
    }

    // Apply every pending added/removed observation by reclustering.
    // Returns a reference to a vector holding, for each index previously
    // returned by addObservation/removeObservation, the cluster found for that
    // past (or null when the past was definitely removed).
    // The contents are only meaningful until observations are changed again.
    // NOTE(review): when there are no pending changes, reCluster() is skipped
    // and the vector is returned as the last reCluster() left it.
    inline ObservedClusters& commitObservations() {
        const bool hasPendingChanges = !deltaObs.empty();
        if (hasPendingChanges) reCluster();
        // start a fresh batch: no pending deltas, user indices restart at 0
        deltaObs.clear();
        deltaObsSeq = 0;
        return userSequence;
    }

    // Absolute complexity of `cluster`:
    // getLogTotalInstances() - cluster->getLogNinstances(),
    // i.e. presumably -log(empirical probability of the cluster).
    inline float getComplexity(ClusterPtr cluster) {
        const float logTotal = getLogTotalInstances();
        const float logCount = cluster->getLogNinstances();
        return logTotal - logCount;
    }

    // Complexity of `cluster` rescaled relative to the other clusters:
    // (maxLog - clusterLog) / (maxLog - minLog). Assuming getMin/MaxLogNinstances
    // are the extremes over all clusters, this yields 0 for the most populated
    // cluster and 1 for the least populated one.
    // A zero range is not inverted, so the result is then 0.
    inline float getScaledComplexity(ClusterPtr cluster) {
        const float top = getMaxLogNinstances();
        const float range = top - getMinLogNinstances();
        const float scale = (range == 0.0) ? range : static_cast<float>(1.0 / range);
        return (top - cluster->getLogNinstances()) * scale;
    }

    // vectorized versions of the previous functions.
    // See commitObservations for what the vector is.
    inline std::vector<float>& getComplexities() {
        userComplexities.clear();
        userComplexities.reserve(userSequence.size());
        float lti = getLogTotalInstances();
        for (typename ObservedClusters::iterator cit = userSequence.begin(); cit!=userSequence.end(); ++cit) {
            Cluster* cluster = cit->get();
            if (cluster) userComplexities.push_back(lti - cluster->getLogNinstances());
            else userComplexities.push_back(0.0f);
        }
        return userComplexities;
    }

    // vectorized versions of the previous functions.
    // See commitObservations for what the vector is.
    inline std::vector<float>& getScaledComplexities() {
        userScaledComplexities.clear();
        userScaledComplexities.reserve(userSequence.size());
        float ml = getMaxLogNinstances();
        float scaleFactor = ml - getMinLogNinstances();
        if (scaleFactor!=0.0) scaleFactor = 1.0 / scaleFactor;
        for (typename ObservedClusters::iterator cit = userSequence.begin(); cit!=userSequence.end(); ++cit) {
            Cluster* cluster = cit->get();
            if (cluster) userScaledComplexities.push_back((ml - cluster->getLogNinstances()) * scaleFactor);
            else userScaledComplexities.push_back(0.0f);
        }
        return userScaledComplexities;
    }

    // Look up the cluster currently assigned to past light cone `plc`.
    // Returns a default-constructed ClusterPtr when this past was never
    // committed. The result is only valid after a commit and before further
    // observations are added or removed.
    inline ClusterPtr getCluster(const PastLightCone& plc) {
        typename DistributionMap::iterator found = observedDistributions.find(plc);
        if (found != observedDistributions.end()) return found->second->cluster;
        return ClusterPtr();
    }

    // Global complexity of the analyzer:
    //   getLogTotalInstances() - sum_i n_i * getLogNinstances(i) / totalInstances
    // where n_i is cluster i's `ninstances`. Assuming the getters return the
    // logs of those counts, this is the entropy of the empirical distribution
    // over causal states.
    // Returns 0 when totalInstances is not positive (nothing committed, or all
    // observations removed); the formula would otherwise divide by zero.
    inline float getGlobalComplexity() {
        // Empty distribution: no uncertainty, and avoids the 0/0 below.
        if (totalInstances <= 0) return 0.0f;
        // This could be made incremental
        float globalComplexity = 0.0;
        for (unsigned int i=0; i<clusters.size(); ++i) {
            Cluster* c = clusters[i];
            globalComplexity += c->ninstances * c->getLogNinstances();
        }
        return getLogTotalInstances() - globalComplexity / totalInstances;
    }
};

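// Resolve forward refs: out-of-line definitions of nested-class members that
// dereference Cluster and therefore need its complete type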
template<class PastLightCone, class FutureLightCone, typename RandomFunctor>
inline CausalStateAnalyzer<PastLightCone,FutureLightCone,RandomFunctor>::CountedPtr::CountedPtr(Cluster* c)
    : cluster(c), count(1)
{
    // A CountedPtr is always created attached to a real cluster, starts with
    // one reference, and registers itself in that cluster's backrefs list.
    assert(c);
    c->backrefs.push_back(this);
}
template<class PastLightCone, class FutureLightCone, typename RandomFunctor>
inline CausalStateAnalyzer<PastLightCone,FutureLightCone,RandomFunctor>::CountedPtr::~CountedPtr()
{
    // Unregister from the cluster's backrefs list, if still attached to one.
    if (!cluster) return;
    cluster->backrefs.remove(this);
}
template<class PastLightCone, class FutureLightCone, typename RandomFunctor>
inline void CausalStateAnalyzer<PastLightCone,FutureLightCone,RandomFunctor>::ClusterPtr::set(Cluster* cluster) {
    // Null argument: drop our reference to the shared indirection, if any.
    if (!cluster) {
        if (indirect) {
            indirect->unref();
            indirect = 0;
        }
        return;
    }
    // Otherwise share the handle the cluster keeps in its own distribution.
    *this = cluster->distribution.cluster;
}

