/***********************************************************************
/ Copyright (c) 2001, Nishan Systems, Inc.
/ All rights reserved.
/ 
/ Redistribution and use in source and binary forms, with or without 
/ modification, are permitted provided that the following conditions are 
/ met:
/ 
/ - Redistributions of source code must retain the above copyright notice, 
/   this list of conditions and the following disclaimer. 
/ 
/ - Redistributions in binary form must reproduce the above copyright 
/   notice, this list of conditions and the following disclaimer in the 
/   documentation and/or other materials provided with the distribution. 
/ 
/ - Neither the name of the Nishan Systems, Inc. nor the names of its 
/   contributors may be used to endorse or promote products derived from 
/   this software without specific prior written permission. 
/ 
/ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
/ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
/ IMPLIED WARRANTIES OF MERCHANTABILITY, NON-INFRINGEMENT AND FITNESS FOR A 
/ PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NISHAN SYSTEMS, INC. 
/ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 
/ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 
/ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 
/ OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 
/ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 
/ OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 
/ ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/
/***********************************************************************/

#include "isns.h"

#ifndef SNS_LINUX
#include <io.h>
#include <winsock2.h>
#include <Ws2tcpip.h>
#else
#define Sleep(a) sleep(a)
#endif

#include <time.h>
#include <signal.h>
#include <sys/poll.h>
#include "util.h"
#include "comm.h"
#include "parse.h"

/* Value passed to alarm() around the connect() call in open_tcp(). */
#define CONNECTION_TIMEOUT 10 

/* Ring of received PDUs: AddMsg() stores at msgCurrentIndex and
   TCPGetMsg() removes from msgCurrentHead. */
MSG_CB msg_q[MSG_Q_SIZE]; /* TCP msg Q */
/* Set to TRUE by AddMsg() once it has cleared msg_q on its first call. */
int msgInitFlag = FALSE;
/* Slot AddMsg() writes next. */
int msgCurrentIndex = 0;
/* Slot TCPGetMsg() reads next. */
int msgCurrentHead = 0;
/* When non-zero, ISNSSendCmd2() does not print a warning for
   ISNS_NO_SUCH_ENTRY_ERR responses. */
int suppress_no_such_entry_err = 0;

#ifndef SNS_LINUX
/* Filled in by WSAStartup() in InitComm(). */
WSADATA g_WsaData;
#endif

/* Socket used by the send/receive routines in this file. The heartbeat
   listeners declare their own local socket with the same name. */
SOCKET fd;

/* Destination address; filled from p_ip and isns_port before use. */
struct sockaddr_in their_addr;
/* Local address the UDP socket is bound to in InitComm(). */
struct sockaddr_in my_addr;


int tcpFlag;   /* TRUE if using TCP */
int isns_port; /* iSNS port */


/* optarg and enableESIFlag are declared here but not referenced by the
   code in this part of the file. */
extern char *optarg;
extern int enableESIFlag;
/* Server address as a dotted-quad string; read via inet_addr() and
   overwritten by the heartbeat listeners. */
extern char p_ip[256];
/* Multicast group address string joined by L3_HeartBeatListener(). */
char multicast_addr[20];

/*
 * Creates the TCP socket and connects it to the server named by p_ip and
 * isns_port.  Does nothing when tcpFlag is not set.
 *
 * An alarm of CONNECTION_TIMEOUT is armed around connect() and cancelled
 * once the connection succeeds.  Any failure terminates the process.
 */
void
open_tcp(void)
{
   int rc;

   if (!tcpFlag)
   {
      return;
   }

   /* TCP Mode */
   fd = socket (AF_INET, SOCK_STREAM, 0);
   if (fd < 0)
   {
      perror ("Fatal Error while calling socket");
      exit(-1);
   }

   /* Destination address of the server */
   their_addr.sin_family = AF_INET;
   their_addr.sin_port = htons ((short) isns_port);
   their_addr.sin_addr.s_addr = inet_addr (p_ip);

   /* Arm the timeout, connect, then disarm it again */
   alarm(CONNECTION_TIMEOUT);

   rc = connect (fd, (struct sockaddr *) &their_addr, sizeof (their_addr));
   if (rc < 0)
   {
      printf ("Fatal Error: Connect.\n");
      exit(-1);
   }

   alarm(0);
}

/*
 * Closes the socket opened by open_tcp().  Does nothing when tcpFlag is
 * not set, so the UDP socket stays open.
 */
void
close_tcp(void)
{
   if (!tcpFlag)
   {
      return;
   }

   close(fd);
}

/***********************************************************************/
/* Initializes communications.                                         */
/*                                                                     */
/* Records the transport mode in tcpFlag, starts Winsock on non-Linux  */
/* builds, optionally blocks until a heartbeat has been received, and  */
/* in UDP mode creates and binds the global datagram socket.           */
/*                                                                     */
/* In TCP mode nothing is opened here: the code that used to do so is  */
/* compiled out with #if 0.  open_tcp() creates the connection.        */
/*                                                                     */
/* Always returns 0; every failure path calls exit(-1).                */
/***********************************************************************/
int
InitComm( int hb_flag, /* Set to non-zero if using heartbeat to find iSNS */
          int l3_hb_flag, /* Non-zero: use L3_HeartBeatListener(); only checked when hb_flag is zero */
          int t_flag  /* Set to non-zero if using TCP */)
{
   tcpFlag=t_flag;
#ifndef SNS_LINUX
   /* Start up the winsock proprietary Stuff */
   if (WSAStartup (MAKEWORD (0x02, 0x00), &g_WsaData) == SOCKET_ERROR)
   {
      exit(-1);
   }
#endif

   /* Heartbeat discovery: the listeners overwrite p_ip and isns_port.
      hb_flag takes precedence over l3_hb_flag. */
   if (hb_flag)
   {
      if (-1==HeartBeatListener())
         exit(-1);
   }
   else if (l3_hb_flag)
   {
      if (-1==L3_HeartBeatListener())
          exit(-1);
   }

   if (tcpFlag)
   {
      /* Everything in this branch is disabled, so TCP mode performs no
         socket setup here. */
#if 0
      /* NOTE(review): tmp is passed to setsockopt() below without being
         initialized -- fix before re-enabling this code. */
      int tmp;

      /* TCP Mode */
      if ((fd = socket (AF_INET, SOCK_STREAM, 0)) < 0)
      {
         perror ("Fatal Error while calling socket");
         exit(-1);
      }

      if (-1 == setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &tmp, 4)) {
         perror ("setsockopt");
         exit(-1);
      }

      /* Setup Variables, Addresses, Etc. */
      their_addr.sin_family = AF_INET; /* host byte order */
      their_addr.sin_port = htons ((short) isns_port);   /* short, network byte order */
      their_addr.sin_addr.s_addr = inet_addr (p_ip);

      /* Do a Connect */
      printf ("Connecting to %s...\n", p_ip);
      if (connect (fd, (struct sockaddr *) &their_addr, sizeof (their_addr))
          < 0)
      {
         printf ("Fatal Error: Connect.\n");
         exit(-1);
      }

      /* Spawn TCP Recv Thread*/
      {
#ifdef SNS_LINUX
         pthread_t junk;
         if (0 != pthread_create (&junk, NULL, (void*) (TCPReceiveMain) , NULL))
            printf ("\n\nThread Creation Failed!\n\n");
#else
         DWORD junk;
         CreateThread ((LPSECURITY_ATTRIBUTES)0, 0, (LPTHREAD_START_ROUTINE) TCPReceiveMain, NULL, 0, &junk);
         Sleep (10);
#endif
      }
#endif
   }
   else
   {
      /* UDP Mode */
      if ((fd = socket (AF_INET, SOCK_DGRAM, 0)) < 0)
      {
         perror ("Fatal Error while calling socket");
         exit(-1);
      }

      /* setup for the bind: any local address, port 0 */
      my_addr.sin_family = AF_INET;
      my_addr.sin_port = htons ((short) 0);
      my_addr.sin_addr.s_addr = INADDR_ANY;

      /* Do a Bind */
      if (bind (fd, (struct sockaddr *) &my_addr, sizeof (my_addr)) < 0)
      {
         printf ("Fatal Error while Binding.\n");
         exit(-1);
      }

      /* Setup Variables, Addresses, Etc. */
      their_addr.sin_family = AF_INET; /* host byte order */
      their_addr.sin_port = htons ((short) isns_port);   /* short, network byte order */
      their_addr.sin_addr.s_addr = inet_addr (p_ip);
   }

   return (0);
}

/***********************************************************************/
/* Used to store PDU which comes on a TCP session.                     */
/*                                                                     */
/* Copies size bytes from p_msg into a freshly allocated buffer and    */
/* stores it in the msg_q slot at msgCurrentIndex, then advances the   */
/* index (wrapping at MSG_Q_SIZE).  The queue is cleared on the first  */
/* call.  TCPGetMsg() frees the stored copy when it is removed.        */
/*                                                                     */
/* If the slot still holds an unread message (the ring has wrapped),   */
/* that older copy is released before the slot is reused.              */
/*                                                                     */
/* Returns 0 on success, -1 if the arguments are invalid or the copy   */
/* cannot be allocated; the queue is left untouched in that case.      */
/***********************************************************************/
int
AddMsg (void *p_msg, int size)
{
   int i;
   void *copy;

   if (FALSE == msgInitFlag)
   {
      for (i = 0; i < MSG_Q_SIZE; i++)
      {
         msg_q[i].p_msg = NULL;
      }

      msgCurrentIndex = 0;
      msgCurrentHead = 0;
      msgInitFlag = TRUE;
   }

   if (NULL == p_msg || size <= 0)
   {
      return (-1);
   }

   /* Allocate before touching the slot so a failure changes nothing. */
   copy = malloc ((size_t) size);
   if (NULL == copy)
   {
      return (-1);
   }
   memcpy (copy, p_msg, (size_t) size);

   /* free(NULL) is a no-op, so this only matters when the slot is
      still occupied. */
   free (msg_q[msgCurrentIndex].p_msg);
   msg_q[msgCurrentIndex].p_msg = copy;
   msg_q[msgCurrentIndex].size = size;

   msgCurrentIndex = (msgCurrentIndex + 1) % MSG_Q_SIZE;

   return (0);
}
/***********************************************************************/
/* Sends a msg to the iSNS without waiting for a response.             */
/*                                                                     */
/* cmd->hdr.len must hold the payload length in host byte order on     */
/* entry; it is converted with htons() in place and left that way.     */
/* In TCP mode the connection is opened before and closed after the    */
/* send.  The PDU is passed to DumpHex() only when it was sent.        */
/*                                                                     */
/* Returns 0 on success, -1 if SendPDU() failed.                       */
/***********************************************************************/
int
ISNSJustSendCmd (ISNS_CMD * cmd)
{
   int e;
   int len = cmd->hdr.len + sizeof (ISNS_HDR);

   open_tcp();

   cmd->hdr.len = htons (cmd->hdr.len);
   e = SendPDU (cmd, len);
   close_tcp();

   /* Check the result first: a negative value must never be used as
      the dump length. */
   if (e < 0)
   {
      printf ("Error Sending.\n");
      return (-1);
   }

//   printf("PDU sent-->\n");
   DumpHex (cmd, len);

   return (0);
}

/***********************************************************************/
/* Called to receive a PDU from the iSNS.                              */
/*                                                                     */
/* cmd  - buffer that receives the PDU (header followed by payload).   */
/* size - capacity of that buffer in bytes.                            */
/*                                                                     */
/* TCP mode: returns whatever get_message() returns, i.e. the payload  */
/* length (header not counted) or -1.                                  */
/* UDP mode: returns the recvfrom() result, i.e. the number of bytes   */
/* in the datagram (header included) or a negative value on error.     */
/* their_addr is overwritten with the sender's address.                */
/***********************************************************************/
int
RcvPDU (ISNS_CMD * cmd, int size)
{
   int e;

   if (tcpFlag)
   {
      /* Using TCP */
      e = get_message ((unsigned char *) cmd, size);
   }
   else
   {
      /* Using UDP.  recvfrom() takes the address length as a
         socklen_t value-result argument. */
      socklen_t len = sizeof (their_addr);

      e = recvfrom (fd, (char *) cmd, size, 0,
                    (struct sockaddr *) &their_addr, &len);

      if (e < 0)
      {
         printf ("Error Receiving.\n");
      }
   }

   return (e);
}

/***********************************************************************/
/* Sends a PDU.  This function will actually call send() or sendto().  */
/*                                                                     */
/* cmd - PDU to transmit, already in wire format.                      */
/* len - number of bytes to transmit.                                  */
/*                                                                     */
/* TCP mode: send() is repeated until all len bytes are written; the   */
/* return value is the total number of bytes sent.                     */
/* UDP mode: their_addr is rebuilt from p_ip and isns_port and the PDU */
/* is sent as one datagram; the return value is the sendto() result.   */
/*                                                                     */
/* Returns a negative value on failure.                                */
/***********************************************************************/
int
SendPDU (ISNS_CMD * cmd, int len)
{
   int e;

   if (tcpFlag)
   {
      /* Using TCP */
      int sent = 0;

      do {
         e = send (fd, (char *) cmd + sent, len - sent, 0);
         /* A zero return with data still outstanding would never make
            progress, so treat it like an error. */
         if (e < 0 || (e == 0 && sent < len))
         {
            printf ("Error Sending.\n");
            return (-1);
         }
         sent += e;
      } while (sent < len);

      /* Report the whole transfer, not just the final chunk. */
      return (sent);
   }

   /* Setup Variables, Addresses, Etc. */
   their_addr.sin_family = AF_INET; /* host byte order */
   their_addr.sin_port = htons ((short) isns_port);   /* short, network byte order */
   their_addr.sin_addr.s_addr = inet_addr (p_ip);

   e = sendto (fd, (char *) cmd, len, 0,
               (struct sockaddr *) &their_addr, sizeof (their_addr));

   if (e < 0)
   {
      printf ("Error Sending.\n");
   }

   return (e);
}

/***********************************************************************/
/* This will send a PDU and wait for a rsp.  The rsp will be ignored.  */
/*                                                                     */
/* cmd->hdr.len must hold the payload length in host byte order on     */
/* entry; it is converted with htons() in place and left that way.     */
/* The response is read into a local buffer and only its status word   */
/* (the first four payload bytes) is examined.                         */
/*                                                                     */
/* Returns the status from the response (0 means success), or a        */
/* negative value if the PDU could not be sent, no response arrived,   */
/* or the response was too short to contain a status.                  */
/***********************************************************************/
int
ISNSSendCmd (ISNS_CMD * cmd)
{
   int e;
   int avail;
   char buffer[1500];
   int len = cmd->hdr.len + sizeof (ISNS_HDR);
   int errorCode;
   uint32_t status;

   open_tcp();

   cmd->hdr.len = htons (cmd->hdr.len);
   e = SendPDU (cmd, len);
   if (e < 0)
   {
      printf ("Error Sending.\n");
      close_tcp();
      return (e);
   }
//   printf("PDU sent-->\n");
   DumpHex (cmd, len);

   e = RcvPDU ((ISNS_CMD *) buffer, sizeof (buffer));
   close_tcp();
   if (e < 0)
   {
      /* Nothing was received: the buffer contents are undefined, so do
         not dump or parse them. */
      printf ("Error Receiving.\n");
      return (e);
   }
//   printf("PDU rcv-->\n");
   DumpHex (buffer, e);

   /* RcvPDU() counts only the payload in TCP mode (get_message()) but
      the whole datagram in UDP mode (recvfrom()). */
   avail = tcpFlag ? e + (int) sizeof (ISNS_HDR) : e;
   if (avail < (int) (sizeof (ISNS_HDR) + sizeof (status)))
   {
      printf ("Error Receiving.\n");
      return (-1);
   }

   /* memcpy avoids an unaligned, type-punned load from the char buffer. */
   memcpy (&status, buffer + sizeof (ISNS_HDR), sizeof (status));
   errorCode = ntohl (status);

   if (errorCode != 0)
   {
      printf ("***WARNING: iSNS returned an error, error=%#x, \"%s\"\n", errorCode, errorText(errorCode));
   }

   /* The response header is not byte-swapped here: the buffer is local
      and is discarded on return. */
   return (errorCode);
}

/***********************************************************************/
/* This will send a PDU and wait for a rsp.  The rsp will be returned. */
/*                                                                     */
/* cmd       - request; cmd->hdr.len holds the payload length in host  */
/*             byte order on entry and is converted with htons() in    */
/*             place.                                                  */
/* rcvBuffer - receives the response PDU.                              */
/* rcvSize   - capacity of rcvBuffer in bytes.                         */
/*                                                                     */
/* When the response status is 0 the header fields in rcvBuffer are    */
/* converted to host byte order; otherwise the buffer is left in wire  */
/* format and a warning is printed (except for ISNS_NO_SUCH_ENTRY_ERR  */
/* while suppress_no_such_entry_err is set).                           */
/*                                                                     */
/* Returns the status from the response (0 means success), or a        */
/* negative value if the PDU could not be sent, no response arrived,   */
/* or the response was too short to contain a status.                  */
/***********************************************************************/
int
ISNSSendCmd2 (ISNS_CMD * cmd, char *rcvBuffer, int rcvSize)
{
   int e;
   int avail;
   int len = cmd->hdr.len + sizeof (ISNS_HDR);
   int errorCode;
   uint32_t status;

   open_tcp();

   cmd->hdr.len = htons (cmd->hdr.len);
   e = SendPDU (cmd, len);
   if (e < 0)
   {
      /* Same policy as ISNSSendCmd(): do not wait for a reply to a
         request that was never sent. */
      printf ("Error Sending.\n");
      close_tcp();
      return (e);
   }
//   printf("PDU sent-->\n");
   DumpHex (cmd, len);

   e = RcvPDU ((ISNS_CMD *) rcvBuffer, rcvSize);

   close_tcp();

   if (e < 0)
   {
      /* Nothing was received: do not dump or parse the buffer. */
      printf ("Error Receiving.\n");
      return (e);
   }
// printf("PDU rcv-->\n");
   DumpHex (rcvBuffer, e);

   /* RcvPDU() counts only the payload in TCP mode (get_message()) but
      the whole datagram in UDP mode (recvfrom()). */
   avail = tcpFlag ? e + (int) sizeof (ISNS_HDR) : e;
   if (avail < (int) (sizeof (ISNS_HDR) + sizeof (status)))
   {
      printf ("Error Receiving.\n");
      return (-1);
   }

   /* memcpy avoids an unaligned, type-punned load from the char buffer. */
   memcpy (&status, rcvBuffer + sizeof (ISNS_HDR), sizeof (status));
   errorCode = ntohl (status);

   if (errorCode != 0)
   {
      /* ISNS_NO_SUCH_ENTRY_ERR is expected when unregistered members of a
         Discovery Domain are queried. Suppress the error display for that
         case. */
      if (!(suppress_no_such_entry_err && 
           (errorCode == ISNS_NO_SUCH_ENTRY_ERR)))
         printf ("***WARNING: iSNS returned an error, error=%#x, \"%s\"\n",
                 errorCode, errorText(errorCode));
   }
   else
   {
      ISNS_HDR *p_cmd;

      p_cmd = (ISNS_HDR *) rcvBuffer;

      /* Convert fields */
      p_cmd->flags = ntohs (p_cmd->flags);
      p_cmd->func_id = ntohs (p_cmd->func_id);
      p_cmd->len = ntohs (p_cmd->len);
      p_cmd->seq = ntohs (p_cmd->seq);
      p_cmd->version = ntohs (p_cmd->version);
      p_cmd->xid = ntohs (p_cmd->xid);
   }

   return (errorCode);
}

/***********************************************************************/
/* This will listen for a heartbeat and using the heartbeat message to
   initialize some variables. */
/***********************************************************************/
int
HeartBeatListener (void)
{
   SNS_Hb *hb_ptr;
   SOCKET fd;
   struct sockaddr_in their_addr;
   struct sockaddr_in my_addr;
   ISNS_CMD cmd;
   int e;
   int len;

   if ((fd = socket (AF_INET, SOCK_DGRAM, 0)) < 0)
   {
      perror ("Fatal Error while calling socket");
      return (-1);
   }

   /* setup for the bind */
   my_addr.sin_family = AF_INET;
   my_addr.sin_port = htons ((short) ISNS_HEARTBEAT_PORT);
   my_addr.sin_addr.s_addr = INADDR_ANY;

   /* Do a Bind */
   if (bind (fd, (struct sockaddr *) &my_addr, sizeof (my_addr)) < 0)
   {
      printf ("Fatal Error while Binding.\n");
      return (-1);
   }


   len = sizeof (their_addr);

   printf ("Waiting for a heartbeat...\n");
   while (1)
   {
      len = sizeof (their_addr);
      e = recvfrom (fd, (char *)&cmd, sizeof (cmd) - sizeof (ISNS_HDR), 0,
                    (struct sockaddr *) &their_addr, &len);

      if (e < 0)
      {
         printf ("***ERROR: recvfrom().\n");
      }
      if (ntohs (cmd.hdr.func_id) != ISNS_HEART_BEAT)
         continue;

      printf ("RCVd: Heart beat-->");
      DumpHex (&cmd, e);

      hb_ptr = (struct sns_hb_payload *)((char *) &cmd + sizeof (ISNS_HDR));

      hb_ptr->counter = ntohl (hb_ptr->counter);
      hb_ptr->interval = ntohl (hb_ptr->interval);
      hb_ptr->tcp_port = ntohs (hb_ptr->tcp_port);
      hb_ptr->udp_port = ntohs (hb_ptr->udp_port);

      isns_port = hb_ptr->udp_port;

      {
         struct in_addr ip;
         ip.s_addr = *(uint32_t *) ((char *) hb_ptr->ip_ptr + 12);
         printf ("SNS IP: %s.\n", inet_ntoa (ip));
         strcpy (p_ip, inet_ntoa (ip));
      }
      printf ("Heartbeat counter: %u.\n", hb_ptr->counter);
      printf ("Heartbeat interval: %u.\n", hb_ptr->interval);
      printf ("Heartbeat tcp_port: %u.\n", hb_ptr->tcp_port);
      printf ("Heartbeat udp_port: %u.\n", hb_ptr->udp_port);

      break;
   }

   return (0);
}

int
L3_HeartBeatListener (void)
{
   SNS_Hb *hb_ptr;
   SOCKET fd;
   struct sockaddr_in their_addr;
   struct sockaddr_in my_addr;
   ISNS_CMD cmd;
   int e;
   int len;
   struct ip_mreq stMreq;
   int iRet;

   if ((fd = socket (AF_INET, SOCK_DGRAM, 0)) < 0)
   {
      perror ("Fatal Error while calling socket");
      return (-1);
   }

   /* setup for the bind */
   my_addr.sin_family = AF_INET;
   my_addr.sin_port = htons ((short) ISNS_HEARTBEAT_PORT);
   my_addr.sin_addr.s_addr = INADDR_ANY;

   /* Do a Bind */
   if (bind (fd, (struct sockaddr *) &my_addr, sizeof (my_addr)) < 0)
   {
      printf ("Fatal Error while Binding.\n");
      return (-1);
   }

   /* join the multicast group */
   stMreq.imr_multiaddr.s_addr = inet_addr(multicast_addr);
   stMreq.imr_interface.s_addr = INADDR_ANY;

   iRet = setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, 
     (char *)&stMreq, sizeof(stMreq));

   if (iRet < 0) {
     printf ("setsockopt() IP_ADD_MEMBERSHIP failed, Err: %d\n",
#ifdef SNS_LINUX
	0);
#else
        WSAGetLastError());
#endif
  } 

   len = sizeof (their_addr);

   printf ("Waiting for a heartbeat...\n");
   while (1)
   {
      e = recvfrom (fd, (char *)&cmd, sizeof (cmd) - sizeof (ISNS_HDR), 0,
                    (struct sockaddr *) &their_addr, &len);

      if (e < 0)
      {
         printf ("***ERROR: recvfrom().\n");
      }
      if (ntohs (cmd.hdr.func_id) != ISNS_HEART_BEAT)
         continue;

      printf ("RCVd: Heart beat-->");
      DumpHex (&cmd, e);

      hb_ptr = (struct sns_hb_payload *)((char *) &cmd + sizeof (ISNS_HDR));

      hb_ptr->counter = ntohl (hb_ptr->counter);
      hb_ptr->interval = ntohl (hb_ptr->interval);
      hb_ptr->tcp_port = ntohs (hb_ptr->tcp_port);
      hb_ptr->udp_port = ntohs (hb_ptr->udp_port);

      isns_port = hb_ptr->udp_port;

      {
         struct in_addr ip;
         ip.s_addr = *(uint32_t *) ((char *) hb_ptr->ip_ptr + 12);
         printf ("SNS IP: %s.\n", inet_ntoa (ip));
         strcpy (p_ip, inet_ntoa (ip));
      }
      printf ("Heartbeat counter: %u.\n", hb_ptr->counter);
      printf ("Heartbeat interval: %u.\n", hb_ptr->interval);
      printf ("Heartbeat tcp_port: %u.\n", hb_ptr->tcp_port);
      printf ("Heartbeat udp_port: %u.\n", hb_ptr->udp_port);

      break;
   }

   return (0);
}

/***********************************************************************/
/* Thread used to receive ALL TCP messages.                            */
/*                                                                     */
/* Reads from the global socket into a local buffer and splits the     */
/* byte stream into PDUs using the length field of each header.  ESI   */
/* PDUs are answered via Send_ESI_Response(), SCN PDUs are passed to   */
/* Process_SCN(), and everything else is queued with AddMsg().         */
/*                                                                     */
/* Bytes of an incomplete PDU are moved to the front of the buffer so  */
/* the next recv() always has room and never writes past the end.      */
/*                                                                     */
/* lparam is unused.  Returns 0 when the peer closes the connection;   */
/* a recv() error or a PDU larger than the buffer ends the process.    */
/***********************************************************************/
#ifdef SNS_LINUX
int
#else
DWORD WINAPI
#endif
TCPReceiveMain (void *lparam)
{
   int e;
   char buffer[2048];
   ISNS_CMD *p_cmd;
   int msg_size;
   int pending;
   int startIndex;
   int endIndex;

   (void) lparam;

   startIndex = 0;
   endIndex = 0;
   while (1)
   {
      /* Only offer the space that is actually left after endIndex. */
      e = recv (fd, &buffer[endIndex], sizeof (buffer) - endIndex, 0);
      if (e < 0)
      {
         printf ("***ERROR: recv().\n");
         exit(-1);
      }
      if (e == 0)
      {
         /* Orderly shutdown by the peer: recv() would keep returning 0,
            so stop instead of spinning. */
         printf ("***ERROR: recv(): connection closed.\n");
         return (0);
      }
      endIndex += e;
      pending = endIndex - startIndex;

      /* Check to see if we received all the message.  ">=" so that a
         PDU consisting of a header only is processed as well. */
      while (pending >= (int) sizeof (ISNS_HDR))
      {
         p_cmd = (ISNS_CMD *)&buffer[startIndex];
         msg_size = ntohs (p_cmd->hdr.len) + sizeof (ISNS_HDR);

         if (pending < msg_size)
            break;

         switch (ntohs (p_cmd->hdr.func_id))
         {
         case ISNS_ESI:
            printf ("Rcv ESI via TCP-->\n");
            DumpHex (p_cmd, msg_size);
            Send_ESI_Response (p_cmd, msg_size);
            break;
         case ISNS_SCN:
            printf ("Rcv SCN via TCP-->\n");
            Process_SCN (p_cmd, msg_size);
            break;
         default:
            AddMsg (p_cmd, msg_size);
            break;
         }
         startIndex += msg_size;
         pending -= msg_size;
      }

      if (pending == 0)
      {
         startIndex = 0;
         endIndex = 0;
      }
      else if (startIndex > 0)
      {
         /* Keep the partial PDU but reclaim the space in front of it. */
         memmove (buffer, &buffer[startIndex], pending);
         startIndex = 0;
         endIndex = pending;
      }

      if (endIndex == (int) sizeof (buffer))
      {
         /* The buffer is full and still holds no complete PDU, so the
            PDU can never fit. */
         printf ("***ERROR: TCP PDU too large.\n");
         exit(-1);
      }
   }
}
/***********************************************************************/
/* Used to retrieve messages from the TCP message Q.                   */
/*                                                                     */
/* Copies the message at msgCurrentHead into buffer, frees the queued  */
/* copy and advances the head (wrapping at MSG_Q_SIZE).                */
/*                                                                     */
/* buffer - destination for the message.                               */
/* b_size - capacity of buffer in bytes.  A longer message is cut to   */
/*          b_size bytes and the remainder is discarded, in the same   */
/*          way an oversized datagram is cut on the UDP path.          */
/*                                                                     */
/* Returns the number of bytes copied, or -1 if the queue is empty.    */
/***********************************************************************/
int
TCPGetMsg (void *buffer, int b_size)
{
   int size;

   if (NULL == (msg_q[msgCurrentHead].p_msg))
      return (-1);

   size = msg_q[msgCurrentHead].size;

   /* Never write more than the caller's buffer can hold. */
   if (size > b_size)
      size = (b_size > 0) ? b_size : 0;

   if (size > 0)
      memcpy (buffer, msg_q[msgCurrentHead].p_msg, (size_t) size);

   free (msg_q[msgCurrentHead].p_msg);
   msg_q[msgCurrentHead].p_msg = NULL;

   msgCurrentHead = (msgCurrentHead + 1) % MSG_Q_SIZE;

   return (size);

}

#define TIMEOUT 3000 // 3 seconds

// Reads one complete iSNS message from the global socket into buf.
//
// A message may arrive as several PDUs. Each PDU's 12-byte header is read
// into the start of buf (so the header of the final PDU is the one that
// remains) and each payload is appended after the payloads already read.
// Reading stops after the PDU whose "last pdu" flag bit is set. The length
// field of the header left in buf is then rewritten with the combined
// payload length, in network byte order.
//
// buf     - destination buffer
// buf_len - capacity of buf in bytes; must be at least 12
//
// Returns length of payload or -1 on error (short read, timeout, or a
// message that does not fit in buf).

int
get_message(unsigned char *buf, int buf_len)
{
	int n, pdu_len, total_len;

	// get_pdu_header() always writes up to 12 bytes, so refuse buffers
	// that cannot even hold a header.
	if (buf == NULL || buf_len < 12)
		return -1;

	total_len = 0;

	do {
		if (get_pdu_header(fd, buf) != 12)
			return -1;

		// payload length: bytes 4-5 of the header, big-endian
		pdu_len = 256 * buf[4] + buf[5];

		if (12 + total_len + pdu_len > buf_len)
			return -1;

		n = get_pdu_payload(fd, buf + 12 + total_len, pdu_len);

		if (n != pdu_len)
			return -1;

		total_len += pdu_len;

	} while ((buf[6] & 0x08) == 0); // while "last pdu" bit is 0

	// put total length in header (network order)

	buf[4] = total_len / 256;
	buf[5] = total_len;

	return total_len;
}

// Get the header part of the iSNS message.
//
// Reads up to 12 bytes from fd into buf, waiting at most TIMEOUT
// milliseconds for each chunk. Stops early if poll() times out or fails,
// if recv() reports an error, or if the peer has closed the connection.
//
// Returns the number of bytes read (12 on success, fewer otherwise).

int
get_pdu_header(int fd, unsigned char *buf)
{
	int n = 0;
	int got;
	struct pollfd pollfd;

	pollfd.fd = fd;
	pollfd.events = POLLIN;

	while (n < 12) {
		if (poll(&pollfd, 1, TIMEOUT) < 1)
			break;
		got = recv(fd, buf + n, 12 - n, 0);
		// -1 must not be added to the count, and 0 means end of
		// stream: poll() would keep reporting the socket readable.
		if (got <= 0)
			break;
		n += got;
	}

	return n;
}

// Get the payload part of the iSNS message.
//
// Reads up to len bytes from fd into payload, waiting at most TIMEOUT
// milliseconds for each chunk. Stops early if poll() times out or fails,
// if recv() reports an error, or if the peer has closed the connection.
//
// Returns the number of bytes read (len on success, fewer otherwise).

int
get_pdu_payload(int fd, unsigned char *payload, int len)
{
	int n = 0;
	int got;
	struct pollfd pollfd;

	pollfd.fd = fd;
	pollfd.events = POLLIN;

	while (n < len) {
		if (poll(&pollfd, 1, TIMEOUT) < 1)
			break;
		got = recv(fd, payload + n, len - n, 0);
		// -1 must not be added to the count, and 0 means end of
		// stream: poll() would keep reporting the socket readable.
		if (got <= 0)
			break;
		n += got;
	}

	return n;
}
