// -*- C++ -*-
#include "Rivet/Analysis.hh"
#include "Rivet/Projections/Beam.hh"
#include "Rivet/Projections/FinalState.hh"
#include "Rivet/Projections/UnstableParticles.hh"

namespace Rivet {


  /// @brief J/Psi -> Sigma+ Sigmabar-
  class BESIII_2023_I2655292 : public Analysis {
  public:

    /// Constructor
    RIVET_DEFAULT_ANALYSIS_CTOR(BESIII_2023_I2655292);

    /// @name Analysis methods
    /// @{

    /// Book histograms and initialise projections before the run
    void init() {

      // Initialise and register projections
      declare(Beam(), "Beams");
      declare(UnstableParticles(), "UFS");
      declare(FinalState(), "FS");
      for(unsigned int ix=0;ix<2;++ix) {
        book(_h_T1[ix], "/TMP/T1_"+toString(ix),20,-1.,1.);
        book(_h_T2[ix], "/TMP/T2_"+toString(ix),20,-1.,1.);
        book(_h_T3[ix], "/TMP/T3_"+toString(ix),20,-1.,1.);
        book(_h_T4[ix], "/TMP/T4_"+toString(ix),20,-1.,1.);
        book(_h_T5[ix], "/TMP/T5_"+toString(ix),20,-1.,1.);
        book(_wsum[ix],"/TMP/wsum_"+toString(ix));
      }
      book(_h_cThetaL,"/TMP/cThetaL",20,-1.,1.);
    }

    void findChildren(const Particle & p,map<long,int> & nRes, int &ncount) {
      for (const Particle &child : p.children()) {
        if (child.children().empty()) {
          nRes[child.pid()]-=1;
          --ncount;
        }
        else {
          findChildren(child,nRes,ncount);
        }
      }
    }

    /// Perform the per-event analysis
    void analyze(const Event& event) {
      // get the axis, direction of incoming electron
      const ParticlePair& beams = apply<Beam>(event, "Beams").beams();
      Vector3 axis;
      if (beams.first.pid()>0) axis = beams.first.mom().p3().unit();
      else                     axis = beams.second.mom().p3().unit();
      // types of final state particles
      const FinalState& fs = apply<FinalState>(event, "FS");
      map<long,int> nCount;
      int ntotal(0);
      for (const Particle& p : fs.particles()) {
        nCount[p.pid()] += 1;
        ++ntotal;
      }
      // loop over Sigma+ baryons
      const UnstableParticles & ufs = apply<UnstableParticles>(event, "UFS");
      Particle Sigma,SigBar;
      bool matched(false);
      for (const Particle& p :  ufs.particles(Cuts::abspid==3222)) {
        if (p.children().empty()) continue;
        map<long,int> nRes=nCount;
        int ncount = ntotal;
        findChildren(p,nRes,ncount);
        matched=false;
        // check for antiparticle
        for (const Particle& p2 :  ufs.particles(Cuts::pid==-p.pid())) {
          if (p2.children().empty()) continue;
          map<long,int> nRes2=nRes;
          int ncount2 = ncount;
          findChildren(p2,nRes2,ncount2);
          if(ncount2==0) {
            matched = true;
            for(const auto& val : nRes2) {
              if (val.second!=0) {
                matched = false;
                break;
              }
            }
            // fond baryon and antibaryon
            if (matched) {
              if (p.pid()>0) {
                Sigma = p;
                SigBar = p2;
              }
              else {
                Sigma = p2;
                SigBar = p;
              }
              break;
            }
          }
        }
        if (matched) break;
      }
      if (!matched) vetoEvent;
      // scattering angle
      const double cosL = axis.dot(Sigma.mom().p3().unit());
      const double sinL = sqrt(1.-sqr(cosL));
      _h_cThetaL->fill(cosL);
      // decay of the Sigma+
      Particle baryon;
      int imode[2]={-1,-1};
      if (Sigma.children()[0].pid()==2212 && Sigma.children()[1].pid()==111) {
        baryon = Sigma.children()[0];
        imode[0] = 0;
      }
      else if (Sigma.children()[1].pid()==2212 && Sigma.children()[0].pid()==111) {
        baryon = Sigma.children()[1];
        imode[0] = 0;
      }
      else if (Sigma.children()[0].pid()==2112 && Sigma.children()[1].pid()==211) {
        baryon = Sigma.children()[0];
        imode[0] = 1;
      }
      else if (Sigma.children()[1].pid()==2112 && Sigma.children()[0].pid()==211) {
        baryon = Sigma.children()[1];
        imode[0] = 1;
      }
      if (imode[0]<0) vetoEvent;
      // decay of the Sigmabar-
      Particle abaryon;
      if (SigBar.children()[0].pid()==-2212 && SigBar.children()[1].pid()==111) {
        abaryon = SigBar.children()[0];
        imode[1] = 0;
      }
      else if (SigBar.children()[1].pid()==-2212 && SigBar.children()[0].pid()==111) {
        abaryon = SigBar.children()[1];
        imode[1] = 0;
      }
      else if (SigBar.children()[0].pid()==-2112 && SigBar.children()[1].pid()==-211) {
        abaryon = SigBar.children()[0];
        imode[1] = 1;
      }
      else if (SigBar.children()[1].pid()==-2112 && SigBar.children()[0].pid()==-211) {
        abaryon = SigBar.children()[1];
        imode[1] = 1;
      }
      if (imode[1]<0) vetoEvent;
      if (imode[0]==imode[1]) vetoEvent;
      // boost to the Sigma rest frame
      LorentzTransform boost1 = LorentzTransform::mkFrameTransformFromBeta(Sigma.mom().betaVec());
      Vector3 e1z = Sigma.mom().p3().unit();
      Vector3 e1y = e1z.cross(axis).unit();
      Vector3 e1x = e1y.cross(e1z).unit();
      Vector3 axis1 = boost1.transform(baryon.mom()).p3().unit();
      double n1x(e1x.dot(axis1)),n1y(e1y.dot(axis1)),n1z(e1z.dot(axis1));
      // boost to the Sigma bar
      LorentzTransform boost2 = LorentzTransform::mkFrameTransformFromBeta(SigBar.mom().betaVec());
      Vector3 axis2 = boost2.transform(abaryon.mom()).p3().unit();
      double n2x(e1x.dot(axis2)),n2y(e1y.dot(axis2)),n2z(e1z.dot(axis2));
      double T1 = sqr(sinL)*n1x*n2x+sqr(cosL)*n1z*n2z;
      double T2 = -sinL*cosL*(n1x*n2z+n1z*n2x);
      double T3 = -sinL*cosL*n1y;
      double T4 = -sinL*cosL*n2y;
      double T5 = n1z*n2z-sqr(sinL)*n1y*n2y;
      _h_T1[imode[0]]->fill(cosL,T1);
      _h_T2[imode[0]]->fill(cosL,T2);
      _h_T3[imode[0]]->fill(cosL,T3);
      _h_T4[imode[0]]->fill(cosL,T4);
      _h_T5[imode[0]]->fill(cosL,T5);
      _wsum[imode[0]]->fill();
    }

    pair<double,pair<double,double> > calcAlpha0(Histo1DPtr hist) const {
      if (hist->numEntries()==0.) return make_pair(0.,make_pair(0.,0.));
      double d = 3./(pow(hist->xMax(),3)-pow(hist->xMin(),3));
      double c = 3.*(hist->xMax()-hist->xMin())/(pow(hist->xMax(),3)-pow(hist->xMin(),3));
      double sum1(0.),sum2(0.),sum3(0.),sum4(0.),sum5(0.);
      for (const auto& bin : hist->bins()) {
        double Oi = bin.sumW();
        if (Oi==0.) continue;
        double a =  d*(bin.xMax() - bin.xMin());
        double b = d/3.*(pow(bin.xMax(),3) - pow(bin.xMin(),3));
        double Ei = bin.errW();
        sum1 +=   a*Oi/sqr(Ei);
        sum2 +=   b*Oi/sqr(Ei);
        sum3 += sqr(a)/sqr(Ei);
        sum4 += sqr(b)/sqr(Ei);
        sum5 +=    a*b/sqr(Ei);
      }
      // calculate alpha
      double alpha = (-c*sum1 + sqr(c)*sum2 + sum3 - c*sum5)/(sum1 - c*sum2 + c*sum4 - sum5);
      // and error
      double cc = -pow((sum3 + sqr(c)*sum4 - 2*c*sum5),3);
      double bb = -2*sqr(sum3 + sqr(c)*sum4 - 2*c*sum5)*(sum1 - c*sum2 + c*sum4 - sum5);
      double aa =  sqr(sum1 - c*sum2 + c*sum4 - sum5)*(-sum3 - sqr(c)*sum4 + sqr(sum1 - c*sum2 + c*sum4 - sum5) + 2*c*sum5);
      double dis = sqr(bb)-4.*aa*cc;
      if (dis>0.) {
        dis = sqrt(dis);
        return make_pair(alpha,make_pair(0.5*(-bb+dis)/aa,-0.5*(-bb-dis)/aa));
      }
      else {
        return make_pair(alpha,make_pair(0.,0.));
      }
    }

    pair<double,double> calcCoeff(unsigned int imode, Histo1DPtr hist) const {
      if (hist->numEntries()==0.) return make_pair(0.,0.);
      double sum1(0.), sum2(0.);
      for (const auto& bin : hist->bins()) {
        double Oi = bin.sumW();
        if (Oi==0.) continue;
        double ai(0.),bi(0.);
        if (imode==0) {
          bi = (pow(1.-sqr(bin.xMin()),1.5) - pow(1.-sqr(bin.xMax()),1.5))/3.;
        }
        else if (imode>=2 && imode<=4) {
          bi = (  pow(bin.xMin(),3) * (-5. + 3.*sqr(bin.xMin()))
                + pow(bin.xMax(),3) * ( 5. - 3.*sqr(bin.xMax())))/15.;
        }
        else {
          assert(false);
        }
        double Ei = bin.errW();
        sum1 += sqr(bi/Ei);
        sum2 += bi/sqr(Ei)*(Oi-ai);
      }
      return make_pair(sum2/sum1,sqrt(1./sum1));
    }

    /// Normalise histograms etc., after the run
    void finalize() {
      normalize(_h_cThetaL);
      for (unsigned int ix=0;ix<2;++ix) {
        scale(_h_T1[ix], 1./ *_wsum[ix]);
        scale(_h_T2[ix], 1./ *_wsum[ix]);
        scale(_h_T3[ix], 1./ *_wsum[ix]);
        scale(_h_T4[ix], 1./ *_wsum[ix]);
        scale(_h_T5[ix], 1./ *_wsum[ix]);
      }
      // first calculate alpha for J/psi -> Sigma+ Sigmabar-
      pair<double,pair<double,double> > alphaPsi = calcAlpha0(_h_cThetaL);
      Estimate0DPtr h_alphaPsi;
      book(h_alphaPsi,1,1,1);
      h_alphaPsi->set(alphaPsi.first, alphaPsi.second);
      double s2 = -1. + sqr(alphaPsi.first);
      double s3 = 3 + alphaPsi.first;
      double s1 = sqr(s3);
      pair<double,pair<double,double> > alpha0     = make_pair(0.,make_pair(0.,0.));
      pair<double,pair<double,double> > alphabar0  = make_pair(0.,make_pair(0.,0.));
      pair<double,pair<double,double> > alphaplus  = make_pair(0.,make_pair(0.,0.));
      pair<double,pair<double,double> > alphaminus = make_pair(0.,make_pair(0.,0.));
      pair<double,pair<double,double> > delta      = make_pair(0.,make_pair(0.,0.));

      // now for the Sigma decays
      for (unsigned int ix=0; ix<2; ++ix) {
        pair<double,double> c_T2 = calcCoeff(2,_h_T2[ix]);
        pair<double,double> c_T3 = calcCoeff(3,_h_T3[ix]);
        pair<double,double> c_T4 = calcCoeff(4,_h_T4[ix]);
        double s4 = sqr(c_T2.first);
        double s5 = sqr(c_T3.first);
        double s6 = sqr(c_T4.first);
        double disc = s1*s5*s6*(-9.*s2*s4 + 4.*s1*s5*s6);
        if (disc<0.) continue;
        disc = sqrt(disc);
        double aM = -sqrt(-1./s2/s6*(2.*s1*s5*s6+disc));
        if (ix==1) aM *=-1;
        double aP = c_T4.first/c_T3.first*aM;
        double aM_P = (2*(alphaPsi.first*c_T4.first*alphaPsi.second.first + c_T4.second*s2)*(disc + 2*s1*s5*s6)
                       - c_T4.first*s2*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alphaPsi.second.first
                       +s3*c_T4.first*c_T3.second +s3*c_T3.first*c_T4.second) +
                       (disc*(- 9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                       + 9*((1 - alphaPsi.first*(3 + 2*alphaPsi.first))* c_T3.first*c_T4.first*alphaPsi.second.first
                       - s2*s3*c_T4.first*c_T3.second - s2*s3*c_T3.first*c_T4.second)* s4
                       + 8*(c_T3.first*c_T4.first*alphaPsi.second.first + s3*c_T4.first*c_T3.second
                       + s3*c_T3.first*c_T4.second)* s1*s5*s6))
                       /(4*pow(3 + alphaPsi.first,3)*pow(c_T3.first,3)*pow(c_T4.first,3)
                       - 9*s2*s3*c_T3.first*c_T4.first*s4)))
                       / (2.*pow(c_T4.first,3)*pow(s2,2)*sqrt(-((disc + 2*s1*s5*s6)/(s2*s6))));
        double aM_M = (2*(alphaPsi.first*c_T4.first*alphaPsi.second.second + c_T4.second*s2)*(disc + 2*s1*s5*s6)
                       - c_T4.first*s2*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alphaPsi.second.second
                       + s3*c_T4.first*c_T3.second +s3*c_T3.first*c_T4.second) +
                       (disc*(- 9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                       + 9*((1 - alphaPsi.first*(3 + 2*alphaPsi.first))* c_T3.first*c_T4.first*alphaPsi.second.second
                       - s2*s3*c_T4.first*c_T3.second - s2*s3*c_T3.first*c_T4.second)* s4
                       + 8*(c_T3.first*c_T4.first*alphaPsi.second.second +  s3*c_T4.first*c_T3.second
                       + s3*c_T3.first*c_T4.second)* s1*s5*s6))
                       / (4*pow(3 + alphaPsi.first,3)*pow(c_T3.first,3)*pow(c_T4.first,3)
                       - 9*s2*s3*c_T3.first*c_T4.first*s4)))
                       /(2.*pow(c_T4.first,3)*pow(s2,2)*sqrt(-((disc + 2*s1*s5*s6)/(s2*s6))));
        double aP_M = (c_T4.first*sqrt(-((disc + 2*s1*s5*s6)
                      / (s2*s6)))* (-2*c_T3.second - (2*alphaPsi.first*c_T3.first*alphaPsi.second.first)/s2
                      + (c_T3.first*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alphaPsi.second.first
                      + s3*c_T4.first*c_T3.second + s3*c_T3.first*c_T4.second)
                      + (disc*(-9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                      + 9*((1 - alphaPsi.first*(3 + 2*alphaPsi.first))* c_T3.first*c_T4.first*alphaPsi.second.first
                      - s2*s3*c_T4.first*c_T3.second - s2*s3*c_T3.first*c_T4.second)* s4
                      + 8*(c_T3.first*c_T4.first*alphaPsi.second.first + s3*c_T4.first*c_T3.second
                      + s3*c_T3.first*c_T4.second)* s1*s5*s6))
                      / (4* pow(3 + alphaPsi.first,3)* pow(c_T3.first,3)* pow(c_T4.first,3)
                      -  9*s2*s3*c_T3.first*c_T4.first*s4)))
                      / (disc + 2*s1*s5*s6)))/(2.*pow(c_T3.first,2));
        double aP_P = (c_T4.first*sqrt(-((disc + 2*s1*s5*s6)/(s2*s6)))
                       * (-2*c_T3.second - (2*alphaPsi.first*c_T3.first*alphaPsi.second.second)/s2
                       + (c_T3.first*(4*s3*c_T3.first*c_T4.first*(c_T3.first*c_T4.first*alphaPsi.second.second
                       + s3*c_T4.first*c_T3.second + s3*c_T3.first*c_T4.second)
                       + (disc*(-9*s2*s3*c_T2.first*c_T3.first*c_T4.first* c_T2.second
                       + 9*((1 - alphaPsi.first*(3 + 2*alphaPsi.first))* c_T3.first*c_T4.first*alphaPsi.second.second
                       - s2*s3*c_T4.first*c_T3.second - s2*s3*c_T3.first*c_T4.second)* s4
                       + 8*(c_T3.first*c_T4.first*alphaPsi.second.second + s3*c_T4.first*c_T3.second
                       + s3*c_T3.first*c_T4.second)* s1*s5*s6))
                       / (4* pow(3 + alphaPsi.first,3)* pow(c_T3.first,3) * pow(c_T4.first,3)
                       - 9*s2*s3*c_T3.first*c_T4.first*s4)))
                       / (disc + 2*s1*s5*s6)))/(2.*pow(c_T3.first,2));
        if (ix==0) {
          alphaminus = make_pair(aP, make_pair(-aP_M , -aP_P));
          alpha0     = make_pair(aM, make_pair(-aM_M , -aM_P));
        }
        else {
          alphabar0 = make_pair(aP, make_pair(-aP_M , -aP_P));
          alphaplus = make_pair(aM, make_pair(-aM_M , -aM_P));
        }
        // now for Delta
        double sDelta = (-2.*(3. + alphaPsi.first)*c_T3.first)/(aM*sqrt(1 - sqr(alphaPsi.first)));
        double cDelta = (-3*(3 + alphaPsi.first)*c_T2.first)/(aM*aP*sqrt(1 - sqr(alphaPsi.first)));

        double Delta = asin(sDelta);
        if (cDelta<0.) Delta = M_PI-Delta;
        double ds_P = (-9*c_T2.first*((-1 + alphaPsi.first)*(1 + alphaPsi.first)
                      * (3 + alphaPsi.first)*c_T3.first*c_T4.first*c_T2.second
                      + c_T2.first*c_T4.first*(c_T3.first*(alphaPsi.second.first
                      + 3*alphaPsi.first*alphaPsi.second.first) -(-1 + alphaPsi.first)
                      * (1 + alphaPsi.first)*(3 + alphaPsi.first)*c_T3.second)
                      - (-1 + alphaPsi.first)*(1 + alphaPsi.first)
                      * (3 + alphaPsi.first)*c_T2.first*c_T3.first*c_T4.second)*disc)
                      / (pow(1 - pow(alphaPsi.first,2),1.5)*pow(c_T4.first,3)*pow(-((disc + 2*s1*s5*s6)
                      / (s2*s6)),1.5)*(-9*s2*s4 + 4*s1*s5*s6));
        double ds_M = (-9*c_T2.first*((-1 + alphaPsi.first)*(1 + alphaPsi.first)
                      * (3 + alphaPsi.first)*c_T3.first*c_T4.first*c_T2.second
                      + c_T2.first*c_T4.first*(c_T3.first*(alphaPsi.second.second
                      + 3*alphaPsi.first*alphaPsi.second.second) -(-1 + alphaPsi.first)
                      * (1 + alphaPsi.first)*(3 + alphaPsi.first)*c_T3.second)
                      - (-1 + alphaPsi.first)*(1 + alphaPsi.first)
                      * (3 + alphaPsi.first)*c_T2.first*c_T3.first*c_T4.second)*disc)
                      / (pow(1 - pow(alphaPsi.first,2),1.5)*pow(c_T4.first,3)*pow(-((disc + 2*s1*s5*s6)
                      / (s2*s6)),1.5)*(-9*s2*s4 + 4*s1*s5*s6));
        ds_P /= sqrt(1.-sqr(sDelta));
        ds_M /= sqrt(1.-sqr(sDelta));
        delta.first += Delta;
        delta.second.first  += sqr(ds_P);
        delta.second.second += sqr(ds_M);
      }
      // delta phi
      Estimate0DPtr h_delta;
      book(h_delta,1,1,2);
      delta.first *= 0.5;
      delta.second.first  = 0.5*sqrt(delta.second.first );
      delta.second.second = 0.5*sqrt(delta.second.second);
      h_delta->set(delta.first, delta.second);
      // alphas
      Estimate0DPtr h_alphaP;
      book(h_alphaP,1,1,3);
      h_alphaP->set(alphaplus.first,alphaplus.second);
      Estimate0DPtr h_alphaM;
      book(h_alphaM,1,1,4);
      h_alphaM->set(alphaminus.first,alphaminus.second);
      Estimate0DPtr h_alpha0,h_alphabar0;
      book(h_alpha0,"TMP/h_alpha0");
      h_alpha0->set(alpha0.first,alpha0.second);
      book(h_alphabar0,"TMP/h_alphabar0");
      h_alphabar0->set(alphabar0.first,alphabar0.second);
      // ratios
      Estimate0DPtr rplus;
      book(rplus,1,1,5);
      divide(h_alphaP,h_alpha0,rplus);
      rplus->setPath("/"+name()+"/"+mkAxisCode(1,1,5));
      Estimate0DPtr rminus;
      book(rminus,1,1,6);
      divide(h_alphaM,h_alphabar0,rminus);
      rminus->setPath("/"+name()+"/"+mkAxisCode(1,1,6));
      //average
      Estimate0DPtr aver;
      book(aver,1,1,8);
      aver->setVal(0.5*(alphaplus.first-alphaminus.first));
      aver->setErr(make_pair(0.5*sqrt(sqr(alphaplus.second.first )+sqr(alphaminus.second.second)),
                             0.5*sqrt(sqr(alphaplus.second.second)+sqr(alphaminus.second.first ))));
    }

    /// @}

    /// @name Histograms
    /// @{
    Histo1DPtr _h_T1[2],_h_T2[2],_h_T3[2],_h_T4[2],_h_T5[2];
    Histo1DPtr _h_cThetaL;
    CounterPtr _wsum[2];
    /// @}

  };


  // Declare BESIII_2023_I2655292 as a Rivet analysis plugin.
  RIVET_DECLARE_PLUGIN(BESIII_2023_I2655292);

}
