#!/usr/bin/env python3
# ==========================================================================
#     _   _ _ ____            ____          _____
#    | | | (_)  _ \ ___ _ __ / ___|___  _ _|_   _| __ __ _  ___ ___ _ __
#    | |_| | | |_) / _ \ '__| |   / _ \| '_ \| || '__/ _` |/ __/ _ \ '__|
#    |  _  | |  __/  __/ |  | |__| (_) | | | | || | | (_| | (_|  __/ |
#    |_| |_|_|_|   \___|_|   \____\___/|_| |_|_||_|  \__,_|\___\___|_|
#
#       ---  High-Performance Connectivity Tracer (HiPerConTracer)  ---
#                 https://www.nntb.no/~dreibh/hipercontracer/
# ==========================================================================
#
# High-Performance Connectivity Tracer (HiPerConTracer)
# Copyright (C) 2015-2025 by Thomas Dreibholz
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Contact: dreibh@simula.no

import sys
import datetime
import ipaddress

import QueryHelper


# ###### Helper functions ###################################################

# Format of the from_time/to_time command-line arguments.
TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'

# Backends that are queried via SQL.
SQL_BACKENDS = [ 'MySQL', 'MariaDB', 'PostgreSQL' ]

# Columns of the SQL queries. The result rows are unpacked positionally in
# exactly this order, i.e. the SELECT list and the row indices used in the
# dump functions below must be kept in sync!
# NOTE(review): RTT_App is used as "the" RTT here; the tables also provide
# RTT_SW and RTT_HW -- confirm which one is wanted.
PING_SQL_COLUMNS = [
   'SendTimestamp',    # row[0]
   'SourceIP',         # row[1]
   'DestinationIP',    # row[2]
   'Checksum',         # row[3]
   'PacketSize',       # row[4]
   'TrafficClass',     # row[5]
   'Status',           # row[6]
   'RTT_App'           # row[7]
]
# NOTE(review): sort/filter columns are derived from the Ping SELECT list
# (SendTimestamp, BurstSeq); confirm against the database schema.
PING_SQL_ORDER = [
   'SendTimestamp', 'MeasurementID', 'SourceIP', 'DestinationIP',
   'Protocol', 'TrafficClass', 'BurstSeq'
]

TRACEROUTE_SQL_COLUMNS = [
   'Timestamp',        # row[0]
   'SourceIP',         # row[1]
   'DestinationIP',    # row[2]
   'RoundNumber',      # row[3]
   'Checksum',         # row[4]
   'PacketSize',       # row[5]
   'TrafficClass',     # row[6]
   'HopNumber',        # row[7]
   'TotalHops',        # row[8]
   'Status',           # row[9]
   'RTT_App',          # row[10]
   'HopIP',            # row[11]
   'PathHash'          # row[12]
]
TRACEROUTE_SQL_ORDER = [
   'Timestamp', 'MeasurementID', 'SourceIP', 'DestinationIP',
   'Protocol', 'TrafficClass', 'RoundNumber', 'HopNumber'
]


def toMicroseconds(timeStamp):
   """Return a naive datetime, interpreted as UTC, as microseconds since the epoch."""
   return round(timeStamp.replace(tzinfo=datetime.timezone.utc).timestamp() * 1000000.0)


def unsigned64(value):
   """Map a signed 64-bit integer to its unsigned two's complement value."""
   if value < 0:
      return 0x10000000000000000 - abs(value)
   return value


def buildSQLQuery(table, columns, timeColumn, orderBy, fromTimeStamp = None, toTimeStamp = None):
   """Build the SELECT statement for a table.

   The optional time range [fromTimeStamp, toTimeStamp) is applied to
   timeColumn. The WHERE clause is placed *before* ORDER BY, as required by
   SQL. toTimeStamp is only used when fromTimeStamp is given.
   The time stamps are datetime objects, so their string form cannot carry
   SQL injected via the command line.
   """
   query = 'SELECT ' + ','.join(columns) + ' FROM ' + table
   if fromTimeStamp is not None:
      query += ' WHERE ' + timeColumn + ' >= \'' + str(fromTimeStamp) + '\''
      if toTimeStamp is not None:
         query += ' AND ' + timeColumn + ' < \'' + str(toTimeStamp) + '\''
   query += ' ORDER BY ' + ','.join(orderBy)
   return query


def buildMongoDBQuery(fromTimeStamp = None, toTimeStamp = None):
   """Build the MongoDB filter for the time range [fromTimeStamp, toTimeStamp).

   The documents' 'timestamp' field is read as an integer by the dump
   functions, so the bounds are given as integers as well (a string bound
   would not compare against a numeric field).
   NOTE(review): assumes 'timestamp' is stored in microseconds since the
   epoch, i.e. the unit the SQL branch writes -- confirm.
   """
   if fromTimeStamp is None:
      return { }
   ft = { 'timestamp': { '$gte': toMicroseconds(fromTimeStamp) } }
   if toTimeStamp is None:
      return ft
   tt = { 'timestamp': { '$lt': toMicroseconds(toTimeStamp) } }
   return { '$and': [ ft, tt ] }


def formatPingLine(sourceIP, destinationIP, timeStamp, checksum, status, rtt, trafficClass, packetSize):
   """Return one "#P" output line; time stamp, checksum and traffic class are hexadecimal."""
   return '#P ' + ' '.join([
      str(sourceIP),
      str(destinationIP),
      format(timeStamp, 'x'),
      format(checksum, 'x'),
      str(status),
      str(rtt),
      format(trafficClass, 'x'),
      str(packetSize) ]) + '\n'


def formatTracerouteLine(sourceIP, destinationIP, timeStamp, roundNumber, checksum,
                         totalHops, statusFlags, pathHash, trafficClass, packetSize):
   """Return one "#T" header line of a Traceroute run."""
   return '#T ' + ' '.join([
      str(sourceIP),
      str(destinationIP),
      format(timeStamp, 'x'),
      str(roundNumber),
      format(checksum, 'x'),
      str(totalHops),
      format(statusFlags, 'x'),
      format(pathHash, 'x'),
      format(trafficClass, 'x'),
      str(packetSize) ]) + '\n'


def formatHopLine(hopNumber, status, rtt, hopIP):
   """Return one hop line (tab-indented) of a Traceroute run; status is hexadecimal."""
   return '\t ' + ' '.join([
      str(hopNumber),
      format(status, 'x'),
      str(rtt),
      str(hopIP) ]) + '\n'


# ###### Query Ping data ####################################################
def dumpPingFromSQL(configuration, fromTimeStamp, toTimeStamp):
   """Query the Ping table of an SQL backend and write the results to stdout."""
   query = buildSQLQuery('Ping', PING_SQL_COLUMNS, 'SendTimestamp', PING_SQL_ORDER,
                         fromTimeStamp, toTimeStamp)
   for row in configuration.query(query):
      # NOTE(review): assumes the time stamp column is returned as naive
      # datetime in UTC -- confirm against the schema.
      sys.stdout.write(formatPingLine(
         sourceIP      = QueryHelper.unmap(ipaddress.ip_address(row[1])),
         destinationIP = QueryHelper.unmap(ipaddress.ip_address(row[2])),
         timeStamp     = toMicroseconds(row[0]),
         checksum      = int(row[3]),
         status        = int(row[6]),
         rtt           = int(row[7]),
         trafficClass  = int(row[5]),
         packetSize    = int(row[4])))


def dumpPingFromMongoDB(configuration, fromTimeStamp, toTimeStamp):
   """Query the 'ping' collection of MongoDB and write the results to stdout."""
   query = buildMongoDBQuery(fromTimeStamp, toTimeStamp)
   for row in configuration.queryMongoDB('ping', query):
      sys.stdout.write(formatPingLine(
         sourceIP      = ipaddress.ip_address(bytes(row['source'])),
         destinationIP = ipaddress.ip_address(bytes(row['destination'])),
         timeStamp     = int(row['timestamp']),
         checksum      = int(row['checksum']),
         status        = int(row['status']),
         rtt           = int(row['rtt']),
         trafficClass  = int(row['tc']),
         packetSize    = int(row['pktsize'])))


# ###### Query Traceroute data ##############################################
def dumpTracerouteFromSQL(configuration, fromTimeStamp, toTimeStamp):
   """Query the Traceroute table of an SQL backend and write the results to stdout.

   The table has one row per hop; the "#T" header is written for the first
   hop of each run only.
   """
   query = buildSQLQuery('Traceroute', TRACEROUTE_SQL_COLUMNS, 'Timestamp', TRACEROUTE_SQL_ORDER,
                         fromTimeStamp, toTimeStamp)
   for row in configuration.query(query):
      hopNumber = int(row[7])
      status    = int(row[9])

      if hopNumber == 1:
         # The lower 8 bits of the status are the per-hop status,
         # the remaining bits are the flags of the whole run.
         # NOTE(review): assumes the time stamp column is returned as naive
         # datetime in UTC -- confirm against the schema.
         sys.stdout.write(formatTracerouteLine(
            sourceIP      = QueryHelper.unmap(ipaddress.ip_address(row[1])),
            destinationIP = QueryHelper.unmap(ipaddress.ip_address(row[2])),
            timeStamp     = toMicroseconds(row[0]),
            roundNumber   = int(row[3]),
            checksum      = int(row[4]),
            totalHops     = int(row[8]),
            statusFlags   = status - (status & 0xff),
            pathHash      = unsigned64(int(row[12])),
            trafficClass  = int(row[6]),
            packetSize    = int(row[5])))

      sys.stdout.write(formatHopLine(
         hopNumber = hopNumber,
         status    = status & 0xff,
         rtt       = int(row[10]),
         hopIP     = QueryHelper.unmap(ipaddress.ip_address(row[11]))))


def dumpTracerouteFromMongoDB(configuration, fromTimeStamp, toTimeStamp):
   """Query the 'traceroute' collection of MongoDB and write the results to stdout.

   Each document describes a whole run; its hops are in the 'hops' list and
   are numbered from 1 in list order.
   """
   query = buildMongoDBQuery(fromTimeStamp, toTimeStamp)
   for row in configuration.queryMongoDB('traceroute', query):
      sys.stdout.write(formatTracerouteLine(
         sourceIP      = ipaddress.ip_address(bytes(row['source'])),
         destinationIP = ipaddress.ip_address(bytes(row['destination'])),
         timeStamp     = int(row['timestamp']),
         roundNumber   = int(row['round']),
         checksum      = int(row['checksum']),
         totalHops     = int(row['totalHops']),
         statusFlags   = int(row['statusFlags']),
         pathHash      = unsigned64(int(row['pathHash'])),
         trafficClass  = int(row['tc']),
         packetSize    = int(row['pktsize'])))

      for hopNumber, hop in enumerate(row['hops'], start = 1):
         sys.stdout.write(formatHopLine(
            hopNumber = hopNumber,
            status    = int(hop['status']),
            rtt       = int(hop['rtt']),
            hopIP     = ipaddress.ip_address(bytes(hop['hop']))))


# ###### Main program #######################################################
def main():
   """Parse the command line, connect to the database and dump the results.

   Exits with status 1 on usage errors, malformed time stamps, unknown query
   types and unsupported database backends.
   """

   # ====== Handle arguments ================================================
   if len(sys.argv) < 3:
      sys.stderr.write('Usage: ' + sys.argv[0] + ' database_configuration ping|traceroute [from_time [to_time]]\n')
      sys.exit(1)
   configurationFile = sys.argv[1]
   queryType         = sys.argv[2]
   fromTimeStamp     = None
   toTimeStamp       = None
   try:
      if len(sys.argv) > 3:
         fromTimeStamp = datetime.datetime.strptime(sys.argv[3], TIMESTAMP_FORMAT)
         if len(sys.argv) > 4:
            toTimeStamp = datetime.datetime.strptime(sys.argv[4], TIMESTAMP_FORMAT)
   except ValueError as e:
      sys.stderr.write('Bad time stamp (expected "YYYY-MM-DD HH:MM:SS"): ' + str(e) + '\n')
      sys.exit(1)

   # Per query type: ( SQL dump function, MongoDB dump function )
   dumpFunctions = {
      'ping':       ( dumpPingFromSQL,       dumpPingFromMongoDB ),
      'traceroute': ( dumpTracerouteFromSQL, dumpTracerouteFromMongoDB )
   }
   # Check the query type before connecting to the database.
   if queryType not in dumpFunctions:
      sys.stderr.write('Specify query type!\n')
      sys.exit(1)
   dumpFromSQL, dumpFromMongoDB = dumpFunctions[queryType]

   # ====== Connect to database =============================================
   configuration = QueryHelper.DatabaseConfiguration(configurationFile)
   # The returned client is not used directly here; the call itself is kept
   # from the original code, which creates the client before querying.
   client = configuration.createClient()

   # ====== Query data ======================================================
   backend = configuration.getBackend()
   if backend in SQL_BACKENDS:
      dumpFromSQL(configuration, fromTimeStamp, toTimeStamp)
   elif backend == 'MongoDB':
      dumpFromMongoDB(configuration, fromTimeStamp, toTimeStamp)
   else:
      sys.stderr.write('Unsupported backend: ' + str(backend) + '\n')
      sys.exit(1)


if __name__ == '__main__':
   main()
