/*
 * ==========================================================================
 *         _   _      _   ____            __ __  __      _
 *        | \ | | ___| |_|  _ \ ___ _ __ / _|  \/  | ___| |_ ___ _ __
 *        |  \| |/ _ \ __| |_) / _ \ '__| |_| |\/| |/ _ \ __/ _ \ '__|
 *        | |\  |  __/ |_|  __/  __/ |  |  _| |  | |  __/ ||  __/ |
 *        |_| \_|\___|\__|_|   \___|_|  |_| |_|  |_|\___|\__\___|_|
 *
 *                  NetPerfMeter -- Network Performance Meter
 *                 Copyright (C) 2009-2026 by Thomas Dreibholz
 * ==========================================================================
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.

 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Contact:  dreibh@simula.no
 * Homepage: https://www.nntb.no/~dreibh/netperfmeter/
 */

#include "control.h"
#include "loglevel.h"
#include "tools.h"

#include <cstring>


// Upper bound for the size (header + payload) of a NETPERFMETER_DATA message
// handled by this file. NOTE: 65536 itself does not fit into the 16-bit
// Header.Length field (it is converted with htons()), so senders must not
// emit messages of exactly this size.
#define MAXIMUM_MESSAGE_SIZE ((size_t)65536)
// Largest payload that fits into a MAXIMUM_MESSAGE_SIZE message.
#define MAXIMUM_PAYLOAD_SIZE (MAXIMUM_MESSAGE_SIZE - sizeof(NetPerfMeterDataMessage))


static void updateStatistics(Flow*                          flowSpec,
                             const unsigned long long       now,
                             const NetPerfMeterDataMessage* dataMsg,
                             const size_t                   received);



// ###### Generate payload pattern ##########################################
// Fills payload[0 .. length-1] with a repeating byte pattern over the
// value range 30..127:
// - Active -> Passive transfer (reverse == false):
//   ascending 30, 31, ..., 127, 30, 31, ...
// - Passive -> Active transfer (reverse == true):
//   descending 127, 126, ..., 30, 127, 126, ...
static void fillPayload(unsigned char* payload,
                        const size_t   length,
                        const bool     reverse = false)
{
   const unsigned int lowest  = 30;
   const unsigned int highest = 127;
   const size_t       period  = highest - lowest + 1;   // 98 distinct values

   for(size_t index = 0; index < length; index++) {
      // Position of this byte within the current cycle of the pattern
      const unsigned int step = (unsigned int)(index % period);
      payload[index] = (unsigned char)(reverse ? (highest - step) : (lowest + step));
   }
}


// ###### Send NETPERFMETER_DATA message ####################################
// Builds a NETPERFMETER_DATA message for the given flow and sends it on the
// flow's socket using the flow's transport protocol.
//
// flow         - flow to send on (IDs, socket, traffic spec, statistics).
// frameID      - frame ID written into the message.
// isFrameBegin - set the NPMDF_FRAME_BEGIN flag.
// isFrameEnd   - set the NPMDF_FRAME_END flag.
// now          - time stamp written into the message.
// bytesToSend  - requested message size including the header. It is raised
//                to the header size if smaller. It is capped so that it
//                neither overflows the output buffer nor exceeds what the
//                16-bit Header.Length field can represent.
//
// Returns the result of the underlying send call: the number of bytes sent,
// or a negative value on error. errno from the send call is preserved for
// the caller.
ssize_t sendNetPerfMeterData(Flow*                    flow,
                             const uint32_t           frameID,
                             const bool               isFrameBegin,
                             const bool               isFrameEnd,
                             const unsigned long long now,
                             size_t                   bytesToSend)
{
   // NOTE(review): the message header is accessed through a cast of a char
   //               buffer; this assumes NetPerfMeterDataMessage has no
   //               alignment requirement beyond char (e.g. a packed struct).
   char                     outputBuffer[MAXIMUM_MESSAGE_SIZE];
   NetPerfMeterDataMessage* dataMsg = (NetPerfMeterDataMessage*)&outputBuffer;

   // ====== Determine message size =========================================
   // The message must at least hold the complete header ...
   if(bytesToSend < sizeof(NetPerfMeterDataMessage)) {
      bytesToSend = sizeof(NetPerfMeterDataMessage);
   }
   // ... and must neither overflow outputBuffer nor exceed the 16-bit
   // Header.Length field (htons(65536) would silently yield a length of 0).
   const size_t maxMessageSize = std::min(sizeof(outputBuffer), (size_t)0xffff);
   if(bytesToSend > maxMessageSize) {
      bytesToSend = maxMessageSize;
   }

   // ====== Prepare NETPERFMETER_DATA message ==============================
   // ------ Create header --------------------------------
   dataMsg->Header.Type   = NETPERFMETER_DATA;
   dataMsg->Header.Flags  = 0x00;
   if(isFrameBegin) {
      dataMsg->Header.Flags |= NPMDF_FRAME_BEGIN;
   }
   if(isFrameEnd) {
      dataMsg->Header.Flags |= NPMDF_FRAME_END;
   }
   dataMsg->Header.Length = htons(bytesToSend);
   dataMsg->MeasurementID = hton64(flow->getMeasurementID());
   dataMsg->FlowID        = htonl(flow->getFlowID());
   dataMsg->StreamID      = htons(flow->getStreamID());
   dataMsg->Padding       = 0x0000;
   dataMsg->FrameID       = htonl(frameID);
   dataMsg->SeqNumber     = hton64(flow->nextOutboundSeqNumber());
   dataMsg->ByteSeqNumber = hton64(flow->getCurrentBandwidthStats().TransmittedBytes);
   dataMsg->TimeStamp     = hton64(now);

   // ------ Create payload data pattern ------------------
   // The pattern direction depends on whether this is an accepted
   // (passive side) or an outgoing (active side) flow.
   fillPayload((unsigned char*)&dataMsg->Payload,
               bytesToSend - sizeof(NetPerfMeterDataMessage),
               flow->isAcceptedIncomingFlow());

   // ====== Send NETPERFMETER_DATA message =================================
   ssize_t sent;
   if(flow->getTrafficSpec().Protocol == IPPROTO_SCTP) {
      sctp_sndrcvinfo sinfo;
      memset(&sinfo, 0, sizeof(sinfo));
      sinfo.sinfo_stream   = flow->getStreamID();
      sinfo.sinfo_ppid     = htonl(PPID_NETPERFMETER_DATA);
      // With probability (1 - ReliableMode), send this message with
      // PR-SCTP partial reliability (time-based or retransmission-based).
      if(flow->getTrafficSpec().ReliableMode < 1.0) {
         const bool sendUnreliable = (randomDouble() > flow->getTrafficSpec().ReliableMode);
         if(sendUnreliable) {
            sinfo.sinfo_timetolive = flow->getTrafficSpec().RetransmissionTrials;
            if(flow->getTrafficSpec().RetransmissionTrialsInMS) {
               sinfo.sinfo_flags |= SCTP_PR_SCTP_TTL;
            }
            else {
               sinfo.sinfo_flags |= SCTP_PR_SCTP_RTX;
            }
         }
      }
      // With probability (1 - OrderedMode), send this message unordered.
      if(flow->getTrafficSpec().OrderedMode < 1.0) {
         const bool sendUnordered = (randomDouble() > flow->getTrafficSpec().OrderedMode);
         if(sendUnordered) {
            sinfo.sinfo_flags |= SCTP_UNORDERED;
         }
      }
      sent = sctp_send(flow->getSocketDescriptor(),
                       (char*)&outputBuffer, bytesToSend,
                       &sinfo, 0);
   }
#ifdef HAVE_QUIC
   else if(flow->getTrafficSpec().Protocol == IPPROTO_QUIC) {
      // The QUIC stream ID encodes the flow's stream ID in the upper bits
      // and the stream type (initiator, unidirectional) in the lower 2 bits.
      int64_t  sid;
      uint32_t flags;
      if(flow->isAcceptedIncomingFlow()) {   // from passive side (server)
         sid   = ((int64_t)flow->getStreamID() << 2) | QUIC_STREAM_TYPE_SERVER_MASK | QUIC_STREAM_TYPE_UNI_MASK;
         flags = (flow->getFirstTransmission() == 0) ? MSG_QUIC_STREAM_NEW : 0;
      }
      else {
         sid   = ((int64_t)flow->getStreamID() << 2) | QUIC_STREAM_TYPE_UNI_MASK;
         flags = 0;   // already created by Identify procedure!
      }
      sent = quic_sendmsg(flow->getSocketDescriptor(), (char*)&outputBuffer, bytesToSend, sid, flags);
   }
#endif
   else if(flow->getTrafficSpec().Protocol == IPPROTO_UDP) {
      if(flow->isRemoteAddressValid()) {
         sent = ext_sendto(flow->getSocketDescriptor(),
                           (char*)&outputBuffer, bytesToSend, 0,
                           flow->getRemoteAddress(),
                           getSocklen(flow->getRemoteAddress()));
      }
      else {
         sent = ext_send(flow->getSocketDescriptor(),
                         (char*)&outputBuffer, bytesToSend, 0);
      }
   }
   else {
      sent = ext_send(flow->getSocketDescriptor(), (char*)&outputBuffer, bytesToSend, 0);
   }

   // Save errno right away: the checks and logging below may overwrite it.
   const int sendErrno = errno;

   // ====== Check, whether flow has been aborted unintentionally ===========
   if((sent < 0) &&
      (sendErrno != EAGAIN) &&
      (!flow->isAcceptedIncomingFlow()) &&
      (flow->getTrafficSpec().ErrorOnAbort) &&
      (flow->getOutputStatus() == Flow::On)) {
      LOG_FATAL
      stdlog << format("Flow #%u on socket %d has been aborted: %s!",
                       flow->getFlowID(), flow->getSocketDescriptor(),
                       strerror(sendErrno)) << "\n";
      LOG_END_FATAL
   }

   errno = sendErrno;   // Let the caller inspect the send error
   return sent;
}


// ###### Transmit data frame ###############################################
// Draws a random frame size from the flow's outbound frame size
// distribution. The frame is then sent as a sequence of NETPERFMETER_DATA
// messages. Each requested message size is at most
// min(MaxMsgSize, MAXIMUM_MESSAGE_SIZE), and sending stops at the first
// failed send.
// Afterwards, the flow's transmission statistics are updated with one frame
// and the number of messages and bytes actually sent.
// Returns true if the whole frame (or an empty frame) has been sent.
bool transmitFrame(Flow*                    flow,
                   const unsigned long long now)
{
   // ====== Determine the frame size =======================================
   ssize_t frameSize =
      (ssize_t)rint(getRandomValue((const double*)&flow->getTrafficSpec().OutboundFrameSize,
                                   flow->getTrafficSpec().OutboundFrameSizeRng));
   if(frameSize < 0) {
      frameSize = 0;   // A negative size means there is nothing to send
   }

   // ====== Send the frame in chunks =======================================
   ssize_t totalSent    = 0;
   size_t  messageCount = 0;
   if(frameSize > 0) {
      const uint32_t frameID = flow->nextOutboundFrameID();
      while(totalSent < frameSize) {
         const ssize_t maxChunk  = std::min((ssize_t)flow->getTrafficSpec().MaxMsgSize,
                                            (ssize_t)MAXIMUM_MESSAGE_SIZE);
         const ssize_t remaining = frameSize - totalSent;
         const ssize_t chunk     = (maxChunk < remaining) ? maxChunk : remaining;
         const bool    firstPart = (totalSent == 0);
         const bool    lastPart  = (totalSent + chunk >= frameSize);

         // NOTE: sendNetPerfMeterData() enlarges a chunk that is smaller
         //       than the message header, so "sent" may exceed "chunk".
         const ssize_t sent =
            sendNetPerfMeterData(flow, frameID, firstPart, lastPart, now, chunk);
         if(sent <= 0) {
            break;   // Transmission error -> stop sending
         }
         totalSent += sent;
         messageCount++;
      }
   }

   // ====== Update statistics ==============================================
   flow->updateTransmissionStatistics(now, 1, messageCount, totalSent);
   return (totalSent >= frameSize);
}


// ###### Handle data message ###############################################
bool handleNetPerfMeterData(const bool               isActiveMode,
                            const unsigned long long now,
                            const int                protocol,
                            const int                sd)
{
   char            inputBuffer[65536];
   sockaddr_union  from;
   socklen_t       fromlen  = sizeof(from);
   int             flags    = 0;
   int64_t         streamID = 0;

   // ====== Read message (or fragment) =====================================
   const ssize_t received =
      FlowManager::getFlowManager()->getMessageReader()->receiveMessage(
         sd, &inputBuffer, sizeof(inputBuffer), &from.sa, &fromlen, &streamID, &flags);
   if(received == MRRM_PARTIAL_READ) {
      return true;   // Partial read -> wait for next fragment.
   }

   // ====== Handle data ====================================================
   if(received > 0) {
      if(!(flags & MSG_NOTIFICATION)) {
         const NetPerfMeterDataMessage*     dataMsg     =
            (const NetPerfMeterDataMessage*)&inputBuffer;
         const NetPerfMeterIdentifyMessage* identifyMsg =
            (const NetPerfMeterIdentifyMessage*)&inputBuffer;

         // ====== Handle NETPERFMETER_IDENTIFY_FLOW message ================
         if( (received >= (ssize_t)sizeof(NetPerfMeterIdentifyMessage)) &&
            (identifyMsg->Header.Type == NETPERFMETER_IDENTIFY_FLOW) &&
            (ntoh64(identifyMsg->MagicNumber) == NETPERFMETER_IDENTIFY_FLOW_MAGIC_NUMBER) ) {
            const bool identifyOkay =
               handleNetPerfMeterIdentify(identifyMsg, sd, &from);
            if(!identifyOkay) {
               LOG_WARNING
               stdlog << format("Failed handling NETPERFMETER_IDENTIFY on socket %d!", sd) << "\n";
               LOG_END
               if(protocol != IPPROTO_UDP) {
                  ext_shutdown(sd, SHUT_RDWR);
               }
               return false;
            }
         }

         // ====== Handle NETPERFMETER_DATA message =========================
         else if( (received >= (ssize_t)sizeof(NetPerfMeterDataMessage)) &&
                  (dataMsg->Header.Type == NETPERFMETER_DATA) ) {
            // ====== Identify flow =========================================
            Flow* flow;
            if( (protocol == IPPROTO_UDP) && (!isActiveMode) ) {
               flow = FlowManager::getFlowManager()->findFlow(&from.sa);
            }
            else if(protocol == IPPROTO_SCTP) {
               // Flow ID = SCTP stream ID:
               flow = FlowManager::getFlowManager()->findFlow(sd, (uint32_t)streamID);
            }
#ifdef HAVE_QUIC
            else if(protocol == IPPROTO_QUIC) {
               // Flow ID = QUIC stream ID (remove 2 last bits:
               const uint32_t flowID = (uint32_t)(streamID >> 2);
               flow = FlowManager::getFlowManager()->findFlow(sd, flowID);
            }
#endif
            else {
               // Flow ID is always 0:
               flow = FlowManager::getFlowManager()->findFlow(sd, 0);
            }
            if(flow) {
               // Update flow statistics by received NETPERFMETER_DATA message.
               updateStatistics(flow, now, dataMsg, received);
            }
            else {
               LOG_WARNING
               stdlog << format("Received NETPERFMETER_DATA for unknown flow on socket %d!", sd) << "\n";
               LOG_END
               if(protocol != IPPROTO_UDP) {
                  ext_shutdown(sd, SHUT_RDWR);
               }
               return false;
            }
         }
         else {
            LOG_WARNING
            stdlog << format("Received garbage on socket %d!", sd) << "\n";
            LOG_END
            if(protocol != IPPROTO_UDP) {
               ext_shutdown(sd, SHUT_RDWR);
            }
            return false;
         }
      }
   }

   // ====== Handle error ===================================================
   else {
      Flow* flow = FlowManager::getFlowManager()->findFlow(sd, 0);
      if(flow) {
         flow->lock();
         if(!flow->isAcceptedIncomingFlow()) {
            // The outgoing flow on the active side is closed:
            LOG_WARNING
            stdlog << format("End of input for flow #%u on socket %d!",
                             flow->getFlowID(), sd) << "\n";
            LOG_END
         }
         else {
            // This is probably just the regular connection shutdown:
            LOG_DEBUG
            stdlog << format("End of input for flow #%u on socket %d!",
                             flow->getFlowID(), sd) << "\n";
            LOG_END
         }
         flow->unlock();
         flow->endOfInput();
         return flow->isAcceptedIncomingFlow();   // No error for incoming flow!
      }
      else {
         LOG_WARNING
         stdlog << format("End of input for unidentified flow on socket %d!",
                          sd) << "\n";
         LOG_END
      }
      if(protocol != IPPROTO_UDP) {
         ext_shutdown(sd, SHUT_RDWR);
      }
      return false;
   }

   return true;
}


// ###### Update flow statistics with incoming NETPERFMETER_DATA message ####
// Computes transit time, delay variation and interarrival jitter for one
// received NETPERFMETER_DATA message. The message is also fed into the
// flow's defragmenter, which reports completed and lost frames/packets/
// bytes. Everything is then passed to flow->updateReceptionStatistics().
// NOTE(review): transit time is (now - TimeStamp) / 1000.0; presumably the
//               time stamps are in microseconds, giving milliseconds.
//               Confirm against the time source.
static void updateStatistics(Flow*                          flow,
                             const unsigned long long       now,
                             const NetPerfMeterDataMessage* dataMsg,
                             const size_t                   receivedBytes)
{
   // ====== Decode header fields and transit time ==========================
   const uint64_t sentAt     = ntoh64(dataMsg->TimeStamp);
   const uint64_t sequenceNo = ntoh64(dataMsg->SeqNumber);
   const double   transit    = ((double)now - (double)sentAt) / 1000.0;

   // ====== Interarrival jitter (RFC 3550, Appendix A.8) ===================
   // D = transit - previous transit;  J += (|D| - J) / 16
   const double delta     = transit - flow->getDelay();
   const double oldJitter = flow->getJitter();
   const double newJitter = oldJitter + (fabs(delta) - oldJitter) / 16.0;

   // ====== Frame/packet/byte loss via the defragmenter ====================
   flow->getDefragmenter()->addFragment(now, dataMsg);
   size_t completedFrames;
   size_t missingFrames;
   size_t missingPackets;
   size_t missingBytes;
   flow->getDefragmenter()->purge(now, flow->getTrafficSpec().DefragmentTimeout,
                                  completedFrames, missingFrames,
                                  missingPackets, missingBytes);

   // ====== Hand everything over to the flow ===============================
   flow->updateReceptionStatistics(now,
                                   completedFrames, receivedBytes,
                                   missingFrames, missingPackets, missingBytes,
                                   (unsigned long long)sequenceNo,
                                   transit, delta, newJitter);
}
