/*
 * Copyright (c) 2005 SBEi, Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version
 * 2 of the License, or (at your option) any later version.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <dirent.h>
#include <errno.h>
#include <netdb.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
/* Non-zero enables the diagnostic printf() calls guarded by "if (trace)". */
int trace = 0;
extern int isns_port;	/* TCP port of the iSNS server */
extern int alias_set;	/* non-zero: register_initiator() also sends alias */
extern char p_ip[256];	/* iSNS server IPv4 address (dotted quad) */
extern char alias[256];
/* Header of the PDU most recently read by get_pdu_header(). */
static unsigned char pdu_header[12];
#define BUF_LEN 10000
/* Shared request/response buffer; bufp is the write cursor while building. */
static unsigned char buf[BUF_LEN], *bufp;
static char initiator_name[1000], entity_id[1000];
static int get_message(int);
static int get_initiator_name(void);
static int get_entity_id(char *);
static int send_and_recv(void);
static int deregister_entity(char *);
static void dump_tlv(int);
static int get_pdu_header(int);
static int get_pdu_payload(int, unsigned char *, int);
static int get_initiators(void);
int get_initiators_full(void);
int get_network_portals(void);
int register_initiator(void);
/* Declared here because initiator_status() calls it before its definition. */
int deregister_initiator(void);
int sync_portals(void);

static void emit_hdr(int);
static void emit_delimiter(void);
static void emit_entity_identifier(char *);
static void emit_entity_protocol(int);
static void emit_iscsi_name(char *);
static void emit_iscsi_node_type(int);
static void emit_iscsi_alias(char *);
static void emit_pg_iscsi_name(char *);
static void emit_addr(int, unsigned int);
static void emit_port(int, unsigned int);
static void emit_pg_tag(int);
static void emit_null_addr(void);
static void emit_null_port(void);
#include <isns.h>

/* iSNSP function IDs, written into the PDU header by emit_hdr(). */
#define DevAttrReg 1
#define DevAttrQry 2
#define DevDereg 4

/* Entity Protocol attribute value sent by register_initiator(). */
#define ENTITY_PROTOCOL_ISCSI 2

/* iSCSI Node Type attribute values. */
#define NODE_TYPE_CONTROL 4
#define NODE_TYPE_INITIATOR 2
#define NODE_TYPE_TARGET 1

/* Failure count after which initiator_status() inserts an extra sleep. */
#define MAX_RETRY_COUNT	2

/* Seconds allowed for connect() before alarm() fires in send_and_recv(). */
#define CONNECTION_TIMEOUT 10

/*
 * Debug output for the status daemon; change "#if 0" to "#if 1" to enable.
 * When disabled, PRINT(...) expands to a no-op statement so its arguments
 * do not form an "expression has no effect" comma expression.
 */
#if 0
#define PRINT printf
#else
#define PRINT(...) ((void)0)
#endif


/*
 * Append a string-valued TLV attribute at bufp.
 *
 * tag:  attribute tag, written as a 32-bit big-endian value.
 * name: NUL-terminated value, or NULL to emit a zero-length attribute.
 *
 * The value includes its terminating NUL and is zero-padded to a multiple
 * of 4 bytes; the 32-bit big-endian length field holds the padded length.
 * The caller must ensure buf has room for 8 + padded-length bytes.
 */
static void
emit_iscsi_ascii_tag(int tag, char *name)
{
	uint32_t t = (uint32_t) tag;
	size_t n, m;

	if (trace)
		printf("Tag: %d=%s\n", tag, name ? name : "(null)");

	*bufp++ = (t >> 24) & 0xff;	// tag
	*bufp++ = (t >> 16) & 0xff;
	*bufp++ = (t >> 8) & 0xff;
	*bufp++ = t & 0xff;

	if (name == NULL) {
		memset(bufp, 0, 4);	// zero length, no value
		bufp += 4;
		return;
	}

	n = strlen(name) + 1;		// value includes the NUL
	m = (n + 3) & ~(size_t) 3;	// padded to a 4-byte boundary

	// Full 32-bit length: a single byte would truncate values >= 256 bytes.
	*bufp++ = (m >> 24) & 0xff;
	*bufp++ = (m >> 16) & 0xff;
	*bufp++ = (m >> 8) & 0xff;
	*bufp++ = m & 0xff;

	memcpy(bufp, name, n);
	memset(bufp + n, 0, m - n);
	bufp += m;
}

/* String-attribute emitters, each a fixed tag passed to emit_iscsi_ascii_tag(). */
#define emit_entity_identifier(name) emit_iscsi_ascii_tag(1, name)
#define emit_iscsi_name(name) emit_iscsi_ascii_tag(32, name)
/* Zero-length tag 33 (the node type tag, see emit_iscsi_node_type()). */
#define emit_null_type() emit_iscsi_ascii_tag(33, NULL)
#define emit_iscsi_alias(name) emit_iscsi_ascii_tag(34, name)
#define emit_null_iscsi_alias() emit_iscsi_ascii_tag(34, NULL)
/*
 * Authentication method (tag 42) is currently not sent: this expands to
 * nothing.  To re-enable, expand to emit_iscsi_ascii_tag(42, name).
 */
#define emit_iscsi_auth(name)

/*
 * Run as a daemon that refreshes the target portal report every 10 seconds.
 * When the portal query fails and the initiator has no Entity ID on the
 * server, the initiator is deregistered and registered again.  After more
 * than MAX_RETRY_COUNT accumulated failures an extra 10-second pause is
 * inserted and the counter resets.
 *
 * Returns -1 if the initiator name cannot be read; otherwise never returns.
 */
int
initiator_status(void)
{
	int failures = 0;

	if (get_initiator_name() < 0)
		return(-1);

	daemon(1, 0);

	for (;;) {
		PRINT("Getting portals for %s\n", initiator_name);
		if (get_network_portals() >= 0) {
			sleep(10);
			continue;
		}

		PRINT("get_network_portals() exception\n");

		if (failures++ > MAX_RETRY_COUNT) {
			sleep(10);
			failures = 0;
		}

		// Still registered: retry the query immediately.
		if (get_entity_id(initiator_name) == 0)
			continue;

		PRINT("Deregistering initiator node\n");
		deregister_initiator();
		PRINT("Registering initiator node %s\n", initiator_name);
		register_initiator();
		sleep(10);
	}

	/* NOTREACHED */
	return(0);
}

static void devattrqry_iscsi_nodes (unsigned char);

/*
 * Query the iSNS server for target nodes (name, alias, portal address and
 * port) and write a human-readable report to
 * /var/spool/isns/network_portals-<p_ip>.
 *
 * Returns 0 on success, -1 if the exchange fails or the report file cannot
 * be opened.
 */
int
get_network_portals(void)
{
	const unsigned char *p, *end;
	FILE *f;
	int n;
	uint32_t tag, len, word;
	char path[512];

	devattrqry_iscsi_nodes(NODE_TYPE_TARGET);
	emit_null_iscsi_alias();
	emit_null_addr();
	emit_null_port();

	if ((n = send_and_recv()) < 0)
		return(-1);

	snprintf(path, sizeof path, "/var/spool/isns/network_portals-%s", p_ip);

	if (!(f = fopen(path, "w")))
		return(-1);

	p = buf + 4;	// skip over status
	end = buf + n;

	// Walk the response TLVs.  Headers are read with memcpy (buf offsets are
	// not guaranteed int-aligned) and every length is validated against the
	// bytes actually received before it is used.
	while (end - p >= 8) {
		memcpy(&word, p, 4);
		tag = ntohl(word);
		memcpy(&word, p + 4, 4);
		len = ntohl(word);
		p += 8;

		if (len > (uint32_t) (end - p))
			break;	// malformed: value runs past the response

		if (len != 0) {
			switch (tag) {
			case 32:	// iSCSI name
				fprintf(f, "---------------------------------\n");
				fprintf(f, "iSCSI ID  : %.*s\n", (int) len, (const char *) p);
				fprintf(f, "Type      : Target\n");
				break;
			case 34:	// alias
				fprintf(f, "Alias     : %.*s\n", (int) len, (const char *) p);
				break;
			case 16:	// 16-byte address; IPv4 is the last 4 bytes
				if (len >= 16) {
					struct in_addr in;

					memcpy(&in.s_addr, p + 12, 4);
					fprintf(f, "Portal IP    : %s\n", inet_ntoa(in));
				}
				break;
			case 17:	// port
				if (len >= 4) {
					memcpy(&word, p, 4);
					fprintf(f, "Portal Port   : %u\n", (unsigned) ntohl(word));
				}
				break;
			default:
				break;
			}
		}

		p += len;
	}

	fclose(f);
	return(0);
}

/*
 * Start a DevAttrQry request in buf: header, this initiator's iSCSI name,
 * the node type attribute, the delimiter, then a null iSCSI name.  Callers
 * append further operating attributes before send_and_recv().
 *
 * Only NODE_TYPE_TARGET is supported.  If the type is unknown or the
 * initiator name cannot be read, the request is left incomplete.
 */
static void devattrqry_iscsi_nodes (unsigned char type)
{
	emit_hdr(DevAttrQry);

	if (type != NODE_TYPE_TARGET) {
		// type is an integer; the original "%s" here was undefined behavior.
		printf("Unknown type: %d\n", type);
		return;
	}

	if (get_initiator_name() < 0)
		return;

	emit_iscsi_name(initiator_name);
	emit_iscsi_node_type(type);
	emit_delimiter();
	emit_iscsi_name(NULL);
}

/*
 * Register this initiator with the iSNS server (DevAttrReg).  The request
 * carries the iSCSI name as source, then after the delimiter: a null entity
 * identifier, the iSCSI entity protocol, the iSCSI name, the initiator node
 * type and, when alias_set is non-zero, the alias.
 *
 * Returns 0 on success, -1 if the name cannot be read or the exchange fails.
 */
int
register_initiator(void)
{
	int rc;

	if (get_initiator_name() < 0)
		return(-1);

	emit_hdr(DevAttrReg);
	emit_iscsi_name(initiator_name);	// source
	emit_delimiter();

	emit_entity_identifier(NULL);
	emit_entity_protocol(ENTITY_PROTOCOL_ISCSI);
	emit_iscsi_name(initiator_name);
	emit_iscsi_node_type(NODE_TYPE_INITIATOR);
	if (alias_set != 0)
		emit_iscsi_alias(alias);

	rc = send_and_recv();
	return (rc < 0) ? rc : 0;
}

/*
 * Directory-entry predicate: returns 1 if the entry name starts with
 * "tpg_", 0 otherwise.  NOTE(review): no caller is visible here.
 */
static int
filter(const struct dirent *p)
{
	static const char prefix[] = "tpg_";

	return strncmp(p->d_name, prefix, sizeof prefix - 1) == 0;
}

/*
 * Remove this initiator from the iSNS server.
 *
 * Uses the Entity ID instead of the iSCSI node name: deregistering by node
 * name alone removes the name but not the portals, and a subsequent
 * registration with the same portals then fails.
 *
 * Returns 0 on success, -1 if the name or Entity ID cannot be obtained or
 * the DevDereg exchange fails.
 */
int
deregister_initiator(void)
{
	if (get_initiator_name() < 0)
		return(-1);

	if (get_entity_id(initiator_name) < 0)
		return(-1);

	if (deregister_entity(initiator_name) < 0)
		return(-1);

	return(0);
}

/*
 * Send a DevDereg request with src as the source iSCSI name and the cached
 * entity_id (as filled in by get_entity_id()) as the operating attribute.
 *
 * Returns 0 on success, -1 on failure.
 */
static int
deregister_entity(char *src)
{
	int rc;

	if (trace)
		printf("deregister_entity()\n");

	emit_hdr(DevDereg);
	emit_iscsi_name(src);
	emit_delimiter();
	emit_entity_identifier(entity_id);

	rc = send_and_recv();
	return (rc < 0) ? -1 : 0;
}

// "An iSNSP message may be sent in one or more iSNS Protocol Data Units."

// poll() timeout in milliseconds for each wait while receiving PDU bytes
// (used by the get_pdu_* readers).
#define TIMEOUT 1000

// Receive one complete iSNS message on sockfd into buf.  Payloads of
// successive PDUs are concatenated until a header has the "last PDU" flag
// (0x08 in header byte 6) set.
// Returns the total payload length, or 0 on a short read/timeout or if the
// message would not fit in BUF_LEN bytes.

int
get_message(int sockfd)
{
	int total = 0;

	for (;;) {
		int pdu_len, got;

		if (get_pdu_header(sockfd) != 12)
			return 0;

		// header bytes 4-5: payload length, big-endian
		pdu_len = (pdu_header[4] << 8) | pdu_header[5];
		if (pdu_len + total > BUF_LEN)
			return 0;

		got = get_pdu_payload(sockfd, buf + total, pdu_len);
		if (got != pdu_len)
			return 0;

		total += pdu_len;

		if (pdu_header[6] & 0x08)	// "last pdu" bit set
			break;
	}

	return total;
}

// Read the 12-byte iSNSP PDU header into pdu_header.
// Returns the number of bytes read: 12 on success, fewer on poll timeout,
// error, or peer close.  The receive loop is the one in get_pdu_payload(),
// which this previously duplicated line for line.

static int
get_pdu_header(int sockfd)
{
	return get_pdu_payload(sockfd, pdu_header, (int) sizeof pdu_header);
}

// Receive up to len bytes into payload, waiting at most TIMEOUT ms for each
// chunk.  Returns the number of bytes read; less than len if poll() times
// out or fails, or recv() reports an error or end of stream.

static int
get_pdu_payload(int sockfd, unsigned char *payload, int len)
{
	struct pollfd pfd;
	int got;

	pfd.fd = sockfd;
	pfd.events = POLLIN;

	for (got = 0; got < len; ) {
		int chunk;

		if (poll(&pfd, 1, TIMEOUT) <= 0)
			break;

		chunk = recv(sockfd, payload + got, len - got, 0);
		if (chunk <= 0)
			break;

		got += chunk;
	}

	return got;
}

/*
 * Read this host's iSCSI initiator name into initiator_name.  Tries
 * SYSFS_INITIATOR_NODENAME first, then PROCFS_INITIATOR_NODENAME.  The
 * file is parsed as "iSCSI InitiatorName: <name>".
 *
 * Returns 0 on success, -1 if neither file can be opened or the name
 * cannot be parsed.
 */
static int
get_initiator_name(void)
{
	FILE *f;
	int parsed;

	if (!(f = fopen(SYSFS_INITIATOR_NODENAME, "r"))) {
		if (!(f = fopen(PROCFS_INITIATOR_NODENAME, "r"))) {
			printf("Unable to locate initiator_nodename\n");
			return(-1);
		}
	}

	// Field width must be sizeof(initiator_name) - 1 so it cannot overflow.
	parsed = fscanf(f, "iSCSI InitiatorName: %999s", initiator_name);
	fclose(f);

	if (parsed != 1) {
		printf("Unable to parse initiator_nodename\n");
		return(-1);
	}

	return(0);
}

/*
 * Look up the Entity ID that owns iSCSI node src and store it in entity_id.
 * Sends a DevAttrQry with src as source and message key, requesting the
 * entity identifier (tag 1) as an operating attribute.
 *
 * Returns 0 if a non-empty tag-1 attribute was found, -1 otherwise.
 */
static int
get_entity_id(char *src)
{
	const unsigned char *p, *end;
	uint32_t tag, len, word;
	int n;

	if (trace)
		printf("get_entity_id()\n");

	emit_hdr(DevAttrQry);
	emit_iscsi_name(src);
	emit_iscsi_name(src);	// Use message key, otherwise we get
					// entity IDs for every node in the
					// discovery domain.
	emit_delimiter();
	emit_entity_identifier(NULL);
	emit_iscsi_name(src);

	n = send_and_recv();
	if (n < 0)
		return(-1);

	/* dump_tlv(n);  -- uncomment to list the response TLVs */

	p = buf + 4;	// skip over status
	end = buf + n;

	// Find the entity id.  Every length is checked against the received
	// bytes, and the copy is bounded by both the TLV and entity_id.
	while (end - p >= 8) {
		memcpy(&word, p, 4);
		tag = ntohl(word);
		memcpy(&word, p + 4, 4);
		len = ntohl(word);
		p += 8;

		if (len > (uint32_t) (end - p))
			break;	// malformed: value runs past the response

		if (tag == 1 && len != 0) {
			const unsigned char *nul = memchr(p, '\0', len);
			size_t cpy = nul ? (size_t) (nul - p) : len;

			if (cpy >= sizeof entity_id)
				cpy = sizeof entity_id - 1;
			memcpy(entity_id, p, cpy);
			entity_id[cpy] = '\0';

			if (trace)
				printf("entity id = %s\n", entity_id);
			return(0);
		}

		p += len;
	}

	return(-1);
}

// Fill in the PDU length of the request in buf[0 .. bufp), send it to the
// iSNS server at p_ip:isns_port, and receive the response message into buf.
// Returns the length of the received payload, or -1 on socket error,
// invalid server address, short/timed-out response, or non-zero iSNS status.
// The socket is closed on every path.

static int
send_and_recv(void)
{
	int len, n, sent, sockfd, status, rc;
	uint32_t word;

	union {
		struct sockaddr sockaddr;
		struct sockaddr_in sockaddr_in;
	} u;

	if (trace)
		printf("   send_and_recv()\n");

	// PDU length (header bytes 4-5, big-endian) excludes the 12-byte header.
	len = bufp - buf;
	buf[4] = ((len - 12) >> 8) & 0xff;
	buf[5] = (len - 12) & 0xff;

	sockfd = socket(AF_INET, SOCK_STREAM, 0);
	if (sockfd == -1) {
		if (trace)
			printf("   socket() errno = %d (%s)\n", errno, strerror(errno));
		return -1;
	}

	memset(&u, 0, sizeof u);
	u.sockaddr_in.sin_family = AF_INET;
	u.sockaddr_in.sin_port = htons((unsigned short) isns_port);
	if (inet_aton(p_ip, &u.sockaddr_in.sin_addr) == 0) {
		if (trace)
			printf("   invalid server address %s\n", p_ip);
		close(sockfd);
		return -1;
	}

	// Bound connect() with an alarm.  NOTE(review): if no SIGALRM handler is
	// installed, the default action of SIGALRM terminates the process.
	alarm(CONNECTION_TIMEOUT);
	rc = connect(sockfd, &u.sockaddr, sizeof u.sockaddr_in);
	alarm(0);

	if (rc == -1) {
		if (trace)
			printf("   connect() errno = %d (%s)\n", errno, strerror(errno));
		close(sockfd);	// previously leaked on every failed connect
		return -1;
	}

	// send() may transfer fewer bytes than requested.
	for (sent = 0; sent < len; sent += n) {
		n = send(sockfd, buf + sent, len - sent, 0);
		if (n <= 0) {
			if (trace)
				printf("   send() errno = %d (%s)\n", errno, strerror(errno));
			close(sockfd);
			return -1;
		}
	}

	if (trace)
		printf("sent %d bytes\n", sent);

	n = get_message(sockfd);

	if (trace)
		printf("received %d payload bytes\n", n);

	close(sockfd);

	if (n < 4)
		return -1;

	memcpy(&word, buf, 4);
	status = (int) ntohl(word);

	if (trace)
		printf("status = %d\n", status);

	return status ? -1 : n;
}

static void
dump_tlv(int n)
{
	int len, tag;
	unsigned char *p;
	if (trace)
		printf("dump_tlv() payload length = %d\n", n);
	p = buf + 4; // skip over status
	while (p < buf + n) {
		tag = ntohl(*((unsigned int *) p));
		p += 4;
		len = ntohl(*((unsigned int *) p));
		p += 4;
		printf("%d %d\n", tag, len);
		p += len;
	}
}

/*
 * Start a new request: reset bufp to buf and write the 12-byte PDU header
 * (version 1, func_id, length 0 to be patched by send_and_recv(), flags
 * 0x8c00 including the 0x08 "last PDU" bit, zero transaction/sequence IDs).
 */
static void
emit_hdr(int func_id)
{
	static const unsigned char hdr_template[12] = {
		0x00, 0x01,	/* version */
		0x00, 0x00,	/* function ID, low byte set below */
		0x00, 0x00,	/* PDU length, filled in later */
		0x8c, 0x00,	/* flags */
		0x00, 0x00,	/* transaction ID */
		0x00, 0x00,	/* sequence ID */
	};

	if (trace)
		printf("   emit_hdr() func = %d\n", func_id);

	memcpy(buf, hdr_template, sizeof hdr_template);
	buf[3] = func_id;
	bufp = buf + sizeof hdr_template;
}

/* Append the 8-byte all-zero delimiter (tag 0, length 0). */
static void
emit_delimiter(void)
{
	if (trace)
		printf(" 0 emit_delimiter()\n");

	memset(bufp, 0, 8);
	bufp += 8;
}

/* Append tag 2 (entity protocol), length 4, value = low byte of proto. */
static void
emit_entity_protocol(int proto)
{
	unsigned char tlv[12] = { 0 };

	if (trace)
		printf(" 2 emit_entity_protocol() proto = %d\n", proto);

	tlv[3] = 2;		/* tag */
	tlv[7] = 4;		/* length */
	tlv[11] = proto;	/* value */

	memcpy(bufp, tlv, sizeof tlv);
	bufp += sizeof tlv;
}

/* Append tag 33 (iSCSI node type), length 4, value = low byte of type. */
static void
emit_iscsi_node_type(int type)
{
	unsigned char tlv[12] = { 0 };

	if (trace)
		printf("33 emit_iscsi_node_type() type = %d\n", type);

	tlv[3] = 33;		/* tag */
	tlv[7] = 4;		/* length */
	tlv[11] = type;		/* value */

	memcpy(bufp, tlv, sizeof tlv);
	bufp += sizeof tlv;
}

static void
emit_pg_iscsi_name(char *name)
{
	int m, n;

	printf("48 emit_pg_iscsi_name() name = %s\n", name);

	n = strlen(name) + 1;

	m = (n + 3) & ~3;

	*bufp++ = 0;	// tag
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = 48;

	*bufp++ = 0;	// length
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = m;

	strcpy(bufp, name);
	bufp += n;
	while (n++ < m)
		*bufp++ = 0;
}

// Append tag 51 (portal group tag) with a 4-byte big-endian value.  The
// value is copied with memcpy rather than stored through an int pointer,
// which violated strict aliasing and assumed bufp was int-aligned.

static void
emit_pg_tag(int tag)
{
	uint32_t be = htonl((uint32_t) tag);

	if (trace)
		printf("51 emit_pg_tag() tag = %d\n", tag);

	*bufp++ = 0;	// tag
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = 51;

	*bufp++ = 0;	// length
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = 4;

	memcpy(bufp, &be, sizeof be);
	bufp += sizeof be;
}

// Append an IP address attribute: tag 49 (portal group address) when pg is
// non-zero, else tag 16 (portal address).  addr is an IPv4 address in host
// byte order, encoded as a 16-byte IPv4-mapped address (::ffff:a.b.c.d).

static void
emit_addr(int pg, unsigned int addr)
{
	static const unsigned char v4mapped_prefix[12] = {
		0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff
	};
	struct in_addr in;
	int tag = pg ? 49 : 16;

	in.s_addr = htonl(addr);	// network order, as inet_ntoa expects

	if (trace)
		printf("%2d emit_addr() addr = %s\n", tag, inet_ntoa(in));

	*bufp++ = 0;	// tag
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = tag;

	*bufp++ = 0;	// length
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = 16;

	memcpy(bufp, v4mapped_prefix, sizeof v4mapped_prefix);	// value
	bufp += sizeof v4mapped_prefix;
	memcpy(bufp, &in.s_addr, 4);
	bufp += 4;
}

/*
 * Append a port attribute: tag 50 (portal group port) when pg is non-zero,
 * else tag 17 (portal port), with a 4-byte big-endian value written via
 * memcpy (no aligned/aliasing unsigned-int store).
 */
static void
emit_port(int pg, unsigned int port)
{
	uint32_t be = htonl(port);
	int tag = pg ? 50 : 17;

	if (trace)
		printf("%2d emit_port() port = %u\n", tag, port);

	*bufp++ = 0;	// tag
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = tag;

	*bufp++ = 0;	// length
	*bufp++ = 0;
	*bufp++ = 0;
	*bufp++ = 4;

	memcpy(bufp, &be, sizeof be);
	bufp += sizeof be;
}

/*
 * Append a zero-length tag-16 (portal address) attribute.  Declared static
 * to match its file-level prototype, like the other emitters.
 */
static void
emit_null_addr(void)
{
	static const unsigned char tlv[8] = { 0, 0, 0, 16, 0, 0, 0, 0 };

	if (trace)
		printf("16 emit_null_addr()\n");

	memcpy(bufp, tlv, sizeof tlv);
	bufp += sizeof tlv;
}

/* Append a zero-length tag-17 (portal port) attribute. */
static void
emit_null_port(void)
{
	unsigned char *q = bufp;

	if (trace)
		printf("17 emit_null_port()\n");

	memset(q, 0, 8);
	q[3] = 17;
	bufp = q + 8;
}

/*
 * Report whether this initiator currently has an Entity ID on the iSNS
 * server.  Returns 0 if the name was read and an Entity ID was found
 * (entity_id is updated as a side effect), -1 otherwise.
 */
int
check_initiator_registration_status(void)
{
	return (get_initiator_name() < 0 || get_entity_id(initiator_name) < 0)
	    ? -1 : 0;
}
