// -*- C++ -*-
#include "Rivet/Analysis.hh"
#include "Rivet/Projections/Beam.hh"
#include "Rivet/Projections/FinalState.hh"
#include "Rivet/Projections/UnstableParticles.hh"

namespace Rivet {


  /// @brief J/Psi, psi(2S) -> Sigma+ Sigmabar-
  class BESIII_2020_I1791570 : public Analysis {
  public:

    /// Constructor
    RIVET_DEFAULT_ANALYSIS_CTOR(BESIII_2020_I1791570);


    /// @name Analysis methods
    /// @{

    /// Book histograms and initialise projections before the run
    void init() {

      // Initialise and register projections
      declare(Beam(), "Beams");
      declare(UnstableParticles(), "UFS");
      declare(FinalState(), "FS");

      // Book histograms

      size_t ih = 1;
      for (double eVal : allowedEnergies()) {

        const string en = toString(round(eVal/MeV));
        if (isCompatibleWithSqrtS(eVal, 1e-3))  _sqs = en;

        book(_h[en+"cThetaL"],"/TMP/cThetaL+"+en,20,-1.,1.);
        book(_h[en+"T1"], "/TMP/T1_"+en,20,-1.,1.);
        book(_h[en+"T2"], "/TMP/T2_"+en,20,-1.,1.);
        book(_h[en+"T3"], "/TMP/T3_"+en,20,-1.,1.);
        book(_h[en+"T4"], "/TMP/T4_"+en,20,-1.,1.);
        book(_h[en+"T5"], "/TMP/T5_"+en,20,-1.,1.);
        book(_h[en+"mu"], 1, 1, ih);
        ++ih;
      }
      raiseBeamErrorIf(_sqs.empty());
    }

    void findChildren(const Particle& p, map<long,int>& nRes, int& ncount) const {
      for (const Particle& child : p.children()) {
        if (child.children().empty()) {
          nRes[child.pid()] -= 1;
          --ncount;
        }
        else {
          findChildren(child,nRes,ncount);
        }
      }
    }

    /// Perform the per-event analysis
    void analyze(const Event& event) {
      // get the axis, direction of incoming electron
      const ParticlePair& beams = apply<Beam>(event, "Beams").beams();
      Vector3 axis;
      if (beams.first.pid()>0)  axis = beams.first .mom().p3().unit();
      else                      axis = beams.second.mom().p3().unit();
      // types of final state particles
      const FinalState& fs = apply<FinalState>(event, "FS");
      map<long,int> nCount;
      int ntotal(0);
      for (const Particle& p : fs.particles()) {
        nCount[p.pid()] += 1;
        ++ntotal;
      }
      // loop over Sigma+ baryons
      const UnstableParticles & ufs = apply<UnstableParticles>(event, "UFS");
      Particle Sigma,SigBar;
      bool matched(false);
      for (const Particle& p :  ufs.particles(Cuts::abspid==3222)) {
        if(p.children().empty()) continue;
        map<long,int> nRes=nCount;
        int ncount = ntotal;
        findChildren(p,nRes,ncount);
        matched=false;
        // check for antiparticle
        for (const Particle& p2 :  ufs.particles(Cuts::pid==-p.pid())) {
          if (p2.children().empty()) continue;
          map<long,int> nRes2=nRes;
          int ncount2 = ncount;
          findChildren(p2,nRes2,ncount2);
          if (ncount2==0) {
            matched = true;
            for (const auto& val : nRes2) {
              if (val.second!=0) {
                matched = false;
                break;
              }
            }
            // fond baryon and antibaryon
            if (matched) {
              if (p.pid()>0) {
                Sigma = p;
                SigBar = p2;
              }
              else {
                Sigma = p2;
                SigBar = p;
              }
              break;
            }
          }
        }
        if (matched) break;
      }
      if (!matched) vetoEvent;
      // find proton
      Particle proton;
      matched = false;
      for (const Particle & p : Sigma.children()) {
        if (p.pid()==2212) {
          matched=true;
          proton=p;
        }
        else if (p.pid()!=111) {
          matched = false;
          break;
        }
      }
      if (!matched) vetoEvent;
      // find antiproton
      Particle pbar;
      matched = false;
      for (const Particle& p : SigBar.children()) {
        if (p.pid()==-2212) {
          matched=true;
          pbar=p;
        }
        else if (p.pid()!=111) {
          matched = false;
          break;
        }
      }
      if (!matched) vetoEvent;
      // boost to the Sigma rest frame
      LorentzTransform boost1 = LorentzTransform::mkFrameTransformFromBeta(Sigma.mom().betaVec());
      Vector3 e1z = Sigma.mom().p3().unit();
      Vector3 e1y = e1z.cross(axis).unit();
      Vector3 e1x = e1y.cross(e1z).unit();
      Vector3 axis1 = boost1.transform(proton.mom()).p3().unit();
      double n1x(e1x.dot(axis1)),n1y(e1y.dot(axis1)),n1z(e1z.dot(axis1));
      // boost to the Sigma bar
      LorentzTransform boost2 = LorentzTransform::mkFrameTransformFromBeta(SigBar.mom().betaVec());
      Vector3 axis2 = boost2.transform(pbar.mom()).p3().unit();
      double n2x(e1x.dot(axis2)),n2y(e1y.dot(axis2)),n2z(e1z.dot(axis2));
      double cosL = axis.dot(Sigma.mom().p3().unit());
      double sinL = sqrt(1.-sqr(cosL));
      double T1 = sqr(sinL)*n1x*n2x+sqr(cosL)*n1z*n2z;
      double T2 = -sinL*cosL*(n1x*n2z+n1z*n2x);
      double T3 = -sinL*cosL*n1y;
      double T4 = -sinL*cosL*n2y;
      double T5 = n1z*n2z-sqr(sinL)*n1y*n2y;
      double mu = -(n1y-n2y);
      _h[_sqs+"T1"]->fill(cosL,T1);
      _h[_sqs+"T2"]->fill(cosL,T2);
      _h[_sqs+"T3"]->fill(cosL,T3);
      _h[_sqs+"T4"]->fill(cosL,T4);
      _h[_sqs+"T5"]->fill(cosL,T5);
      _h[_sqs+"mu"]->fill(cosL,mu);
      _h[_sqs+"cThetaL"]->fill(cosL);
    }


    pair<double,pair<double,double> > calcAlpha0(const Histo1DPtr& hist) const {
      if (hist->numEntries()==0.) return make_pair(0.,make_pair(0.,0.));
      double d = 3./(pow(hist->xMax(),3)-pow(hist->xMin(),3));
      double c = 3.*(hist->xMax()-hist->xMin())/(pow(hist->xMax(),3)-pow(hist->xMin(),3));
      double sum1(0.),sum2(0.),sum3(0.),sum4(0.),sum5(0.);
      for (const auto& bin : hist->bins()) {
        double Oi = bin.sumW();
        if (Oi==0.) continue;
        double a =  d*(bin.xMax() - bin.xMin());
        double b = d/3.*(pow(bin.xMax(),3) - pow(bin.xMin(),3));
        double Ei = bin.errW();
        sum1 +=   a*Oi/sqr(Ei);
        sum2 +=   b*Oi/sqr(Ei);
        sum3 += sqr(a)/sqr(Ei);
        sum4 += sqr(b)/sqr(Ei);
        sum5 +=    a*b/sqr(Ei);
      }
      // calculate alpha
      double alpha = (-c*sum1 + sqr(c)*sum2 + sum3 - c*sum5)/(sum1 - c*sum2 + c*sum4 - sum5);
      // and error
      double cc = -pow((sum3 + sqr(c)*sum4 - 2*c*sum5),3);
      double bb = -2*sqr(sum3 + sqr(c)*sum4 - 2*c*sum5)*(sum1 - c*sum2 + c*sum4 - sum5);
      double aa =  sqr(sum1 - c*sum2 + c*sum4 - sum5)*(-sum3 - sqr(c)*sum4 + sqr(sum1 - c*sum2 + c*sum4 - sum5) + 2*c*sum5);
      double dis = sqr(bb)-4.*aa*cc;
      if (dis>0.) {
        dis = sqrt(dis);
        return make_pair(alpha,make_pair(0.5*(-bb+dis)/aa,-0.5*(-bb-dis)/aa));
      }
      else {
        return make_pair(alpha,make_pair(0.,0.));
      }
    }

    pair<double,double> calcCoeff(size_t imode, const Histo1DPtr& hist) const {
      if (hist->numEntries()==0.) return make_pair(0.,0.);
      double sum1(0.),sum2(0.);
      for (const auto& bin : hist->bins()) {
        double Oi = bin.sumW();
        if (Oi==0.) continue;
        double ai(0.),bi(0.);
        if (imode==0) {
          bi = (pow(1.-sqr(bin.xMin()),1.5) - pow(1.-sqr(bin.xMax()),1.5))/3.;
        }
        else if (imode>=2 && imode<=4) {
          bi = ( pow(bin.xMin(),3)*( -5. + 3.*sqr(bin.xMin()))  +
                 pow(bin.xMax(),3)*(  5. - 3.*sqr(bin.xMax())))/15.;
        }
        else {
          assert(false);
        }
        double Ei = bin.errW();
        sum1 += sqr(bi/Ei);
        sum2 += bi/sqr(Ei)*(Oi-ai);
      }
      return make_pair(sum2/sum1,sqrt(1./sum1));
    }

    /// Normalise histograms etc., after the run
    void finalize() {

      size_t ih = 1;
      for (double eVal : allowedEnergies()) {

        const string en = toString(round(eVal/MeV));

        const double sf = _h[en+"cThetaL"]->sumW();
        normalize(_h[en+"cThetaL"]);
        scale(_h[en+"T1"], 1.0/sf);
        scale(_h[en+"T2"], 1.0/sf);
        scale(_h[en+"T3"], 1.0/sf);
        scale(_h[en+"T4"], 1.0/sf);
        scale(_h[en+"T5"], 1.0/sf);
        scale(_h[en+"mu"], 2.0/sf);

        // calculate alpha0
        pair<double,pair<double,double> > alpha0 = calcAlpha0(_h[en+"cThetaL"]);
        Estimate0DPtr est;
        book(est,4,1,ih);
        est->set(alpha0.first, alpha0.second);
        double s2 = -1. + sqr(alpha0.first);
        double s3 = 3 + alpha0.first;
        double s1 = sqr(s3);
        // alpha- and alpha+ from proton data
        pair<double,double> c_T2 = calcCoeff(2,_h[en+"T2"]);
        pair<double,double> c_T3 = calcCoeff(3,_h[en+"T3"]);
        pair<double,double> c_T4 = calcCoeff(4,_h[en+"T4"]);
        double s4 = sqr(c_T2.first);
        double s5 = sqr(c_T3.first);
        double s6 = sqr(c_T4.first);
        double disc = s1*s5*s6*(-9.*s2*s4 + 4.*s1*s5*s6);
        if (disc>=0.) {
          disc = sqrt(disc);
          double aM = -sqrt(-1./s2/s6*(2.*s1*s5*s6+disc));
          double aP = c_T4.first/c_T3.first*aM;
          double aM_P = (2*(alpha0.first*c_T4.first*alpha0.second.first + c_T4.second*s2)*(disc + 2*s1*s5*s6)
                         - c_T4.first*s2*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alpha0.second.first
                         + s3*c_T4.first*c_T3.second +s3*c_T3.first*c_T4.second)
                         + (disc*(- 9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                         + 9*((1 -  alpha0.first*(3 + 2*alpha0.first))* c_T3.first*c_T4.first*alpha0.second.first
                         - s2*s3*c_T4.first*c_T3.second - s2*s3*c_T3.first*c_T4.second)* s4
                         + 8*(c_T3.first*c_T4.first*alpha0.second.first + s3*c_T4.first*c_T3.second
                         + s3*c_T3.first*c_T4.second)* s1*s5*s6))
                         / (4*pow(3 + alpha0.first,3)*pow(c_T3.first,3)*pow(c_T4.first,3) -9*s2*s3*c_T3.first*c_T4.first*s4)))
                         / (2.*pow(c_T4.first,3)*pow(s2,2)*sqrt(-((disc + 2*s1*s5*s6)/(s2*s6))));
          double aM_M = (2*(alpha0.first*c_T4.first*alpha0.second.second + c_T4.second*s2)*(disc + 2*s1*s5*s6)
                         - c_T4.first*s2*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alpha0.second.second
                         + s3*c_T4.first*c_T3.second +s3*c_T3.first*c_T4.second)
                         + (disc*(- 9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                         + 9*((1 - alpha0.first*(3 + 2*alpha0.first))* c_T3.first*c_T4.first*alpha0.second.second
                         - s2*s3*c_T4.first*c_T3.second - s2*s3*c_T3.first*c_T4.second)* s4
                         + 8*(c_T3.first*c_T4.first*alpha0.second.second +  s3*c_T4.first*c_T3.second
                         + s3*c_T3.first*c_T4.second)* s1*s5*s6))
                         / (4*pow(3 + alpha0.first,3)*pow(c_T3.first,3)*pow(c_T4.first,3) -9*s2*s3*c_T3.first*c_T4.first*s4)))
                         / (2.*pow(c_T4.first,3)*pow(s2,2)*sqrt(-((disc + 2*s1*s5*s6)/(s2*s6))));
          double aP_M = (c_T4.first*sqrt(-((disc + 2*s1*s5*s6)/ (s2*s6)))
                        * (-2*c_T3.second -  (2*alpha0.first*c_T3.first*alpha0.second.first)/s2
                        + (c_T3.first*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alpha0.second.first
                        + s3*c_T4.first*c_T3.second + s3*c_T3.first*c_T4.second)
                        + (disc*(-9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                        + 9*((1 -  alpha0.first*(3 + 2*alpha0.first))* c_T3.first*c_T4.first*alpha0.second.first
                        - s2*s3*c_T4.first*c_T3.second -  s2*s3*c_T3.first*c_T4.second)* s4
                        + 8*(c_T3.first*c_T4.first*alpha0.second.first +  s3*c_T4.first*c_T3.second
                        + s3*c_T3.first*c_T4.second)* s1*s5*s6))/ (4* pow(3 + alpha0.first,3)* pow(c_T3.first,3) * pow(c_T4.first,3)
                        - 9*s2*s3*c_T3.first*c_T4.first*s4)))/ (disc + 2*s1*s5*s6)))/(2.*pow(c_T3.first,2));
          double aP_P = (c_T4.first*sqrt(-((disc + 2*s1*s5*s6)/ (s2*s6)))
                        * (-2*c_T3.second -  (2*alpha0.first*c_T3.first*alpha0.second.second)/s2
                        + (c_T3.first*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alpha0.second.second
                        + s3*c_T4.first*c_T3.second + s3*c_T3.first*c_T4.second)
                        + (disc*(-9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                        +  9*((1 - alpha0.first*(3 + 2*alpha0.first))* c_T3.first*c_T4.first*alpha0.second.second
                        - s2*s3*c_T4.first*c_T3.second -  s2*s3*c_T3.first*c_T4.second)* s4
                        + 8*(c_T3.first*c_T4.first*alpha0.second.second +  s3*c_T4.first*c_T3.second
                        + s3*c_T3.first*c_T4.second)* s1*s5*s6)) / (4* pow(3 + alpha0.first,3)* pow(c_T3.first,3)* pow(c_T4.first,3)
                        - 9*s2*s3*c_T3.first*c_T4.first*s4)))/ (disc + 2*s1*s5*s6)))/(2.*pow(c_T3.first,2));
          book(est,2,1,1);
          est->set(aM, make_pair(-aM_M , -aM_P ));
          book(est,2,1,2);
          est->set(aP, make_pair(-aP_M , -aP_P  ));
          book(est,2,1,3);
          est->set(0.5*(aM-aP), make_pair(0.5*sqrt(sqr(aM_M)+sqr(aP_P)) ,
                                          0.5*sqrt(sqr(aM_P)+sqr(aP_M))));
          // now for Delta
          double sDelta = (-2.*(3. + alpha0.first)*c_T3.first)/(aM*sqrt(1 - sqr(alpha0.first)));
          double cDelta = (-3*(3 + alpha0.first)*c_T2.first)/(aM*aP*sqrt(1 - sqr(alpha0.first)));

          double Delta = asin(sDelta);
          if (cDelta<0.) Delta = M_PI-Delta;
          double ds_P = (-9*c_T2.first*((-1 + alpha0.first)*(1 + alpha0.first) * (3 + alpha0.first)*c_T3.first*c_T4.first*c_T2.second
                        + c_T2.first*c_T4.first*(c_T3.first*(alpha0.second.first + 3*alpha0.first*alpha0.second.first)
                        - (-1 + alpha0.first)*(1 + alpha0.first)*(3 + alpha0.first)*c_T3.second)
                        -  (-1 + alpha0.first)*(1 + alpha0.first)*  (3 + alpha0.first)*c_T2.first*c_T3.first*c_T4.second)*disc)
                        / (pow(1 - pow(alpha0.first,2),1.5)*pow(c_T4.first,3)*pow(-((disc + 2*s1*s5*s6)
                        / (s2*s6)),1.5)*(-9*s2*s4 + 4*s1*s5*s6));
          double ds_M = (-9*c_T2.first*((-1 + alpha0.first)*(1 + alpha0.first) * (3 + alpha0.first)*c_T3.first*c_T4.first*c_T2.second
                        + c_T2.first*c_T4.first*(c_T3.first*(alpha0.second.second + 3*alpha0.first*alpha0.second.second)
                        - (-1 + alpha0.first)*(1 + alpha0.first)*(3 + alpha0.first)*c_T3.second)
                        -  (-1 + alpha0.first)*(1 + alpha0.first)*  (3 + alpha0.first)*c_T2.first*c_T3.first*c_T4.second)*disc)
                        / (pow(1 - pow(alpha0.first,2),1.5)*pow(c_T4.first,3)*pow(-((disc + 2*s1*s5*s6)
                        / (s2*s6)),1.5)*(-9*s2*s4 + 4*s1*s5*s6));
          ds_P /= sqrt(1.-sqr(sDelta));
          ds_M /= sqrt(1.-sqr(sDelta));
          book(est,3,1,ih);
          est->set(Delta/M_PI*180., make_pair( -ds_P/M_PI*180., -ds_M/M_PI*180. ));

        }
        ++ih;
      }
    }

    /// @}


    /// @name Histograms
    /// @{
    map<string,Histo1DPtr> _h;
    string _sqs = "";
    /// @}

  };


  RIVET_DECLARE_PLUGIN(BESIII_2020_I1791570);

}
