/*
 * feedbackd - dynamic feedback system for LVS
 * Copyright (C) 2002 Jeremy Kerr
 *
 * This file is part of feedbackd.
 *
 *  feedbackd is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  feedbackd is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with feedbackd; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

/**
 * @file necp_listener.c
 * NECP listener
 *
 * The listener opens its own TCP socket on NECP_PORT, accepts connections
 * from servers, and receives and processes the NECP packets they send.
 */


#include "necp_listener.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <assert.h>
#include <signal.h>

#include <sys/select.h>

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "necp.h"
#include "necp_handler.h"
#include "log.h"
#include "scheduler.h"
#include "linkedlist.h"

/**
 * The largest payload we're willing to malloc for, in payload units
 * (struct necp_payload_unit). A packet whose payload_len exceeds
 * MAX_PAYLOAD_N * sizeof(struct necp_payload_unit) bytes is discarded
 * instead of processed. May be overridden at build time.
 */
#ifndef MAX_PAYLOAD_N
#define MAX_PAYLOAD_N 10
#endif

/**
 * Handles a new connection - accepts from the main socket, creating a new
 * socket descriptor for the connection. An existing server-table entry for
 * the same address is reused; otherwise a new entry is added (the connection
 * is rejected if that allocation fails).
 *
 * @param sd the socket descriptor to accept() from
 */
static void new_connection(int sd);

/**
 * Handles interrupts - passed as the callback to signal()
 *
 * @param signum The number of the signal caught
 */
void interrupt(int signum);

/**
 * Create, configure, bind and listen on the main NECP socket.
 *
 * The socket is a TCP socket bound to INADDR_ANY on NECP_PORT with
 * SO_REUSEADDR set (and SO_BINDANY where the platform defines it).
 * Any failure is fatal and terminates the process.
 *
 * @return the listening socket descriptor
 */
static int open_listen_socket(void)
{
	struct sockaddr_in addr;
	int sd;
	int opt;

	if ((sd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	opt = 1;
	if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, (void *)&opt,
				sizeof(opt))) {
		perror("setsockopt");
		exit(EXIT_FAILURE);
	}
#ifdef SO_BINDANY
	opt = 1;
	if (setsockopt(sd, SOL_SOCKET, SO_BINDANY, (void *)&opt,
				sizeof(opt))) {
		perror("setsockopt");
		exit(EXIT_FAILURE);
	}
#endif /* SO_BINDANY */

	memset(&addr, 0, sizeof(addr));
	addr.sin_family		= AF_INET;
	addr.sin_addr.s_addr	= htonl(INADDR_ANY);
	addr.sin_port		= htons(NECP_PORT);

	if (bind(sd, (struct sockaddr *)&addr, (socklen_t)sizeof(addr)) < 0) {
		perror("bind");
		exit(EXIT_FAILURE);
	}

	/* backlog of 2 pending connections */
	if (listen(sd, 2)) {
		perror("listen");
		exit(EXIT_FAILURE);
	}

	return sd;
}

/**
 * (Re)build the select() read set: the listening socket plus the socket
 * of every server in the servers list that is connected (sd != 0).
 *
 * @param set       the set to initialise
 * @param listen_sd the main listening socket
 * @return the highest descriptor placed in @p set
 */
static int build_read_set(fd_set *set, int listen_sd)
{
	listitem *i;
	int maxfd = listen_sd;

	FD_ZERO(set);
	FD_SET(listen_sd, set);

	list_for_each(servers, i) {
		struct server *server = (struct server *)i->item;

		if (server->sd) {
			FD_SET(server->sd, set);
			if (server->sd > maxfd)
				maxfd = server->sd;
		}
	}

	return maxfd;
}

/**
 * Main NECP event loop.
 *
 * Opens the listening socket and installs the SIGINT handler. It then loops
 * while necp_listener_running is set. Each pass waits in select() for a new
 * connection, data from a connected server, or the next scheduled task, and
 * dispatches whichever occurred. A signal interruption is not an error. Any
 * other select() failure closes and reopens the listening socket.
 */
void necp_listen()
{
	int sd = open_listen_socket();

	if (signal(SIGINT, interrupt) == SIG_ERR) {
		perror("signal");
		exit(EXIT_FAILURE);
	}

	while (necp_listener_running) {
		fd_set fdset;
		struct timeval timeout;
		struct timeval *tv = NULL;
		listitem *i;
		int maxfd, n, t;

		/* select() overwrites the set, and the servers list may have
		 * changed since the last pass (including after a socket
		 * reopen), so rebuild it every time. */
		maxfd = build_read_set(&fdset, sd);

		/* see if & when to interrupt the select to process a
		 * scheduled task; t == 0 means block until activity */
		if ((t = get_timeout()) != 0) {
			/* a negative interval makes select() fail with EINVAL;
			 * an overdue task should simply run straight away */
			timeout.tv_sec  = t < 0 ? 0 : t;
			timeout.tv_usec = 0;
			tv = &timeout;
		}

		/* wait for data ... */
		log_printf(LOG_VDEBUG,
		    "Going into select(), timeout in %d",
		    tv ? (int)tv->tv_sec : 0);

		n = select(maxfd + 1, &fdset, NULL, NULL, tv);

		if (n < 0) {
			int err = errno;

			if (err == EINTR) {
				/* Interrupted by a signal: leave the loop if
				 * our SIGINT handler cleared the flag,
				 * otherwise just wait again. */
				if (!necp_listener_running)
					log_printf(LOG_WARN,
					    "Caught signal, exiting");
				continue;
			}

			perror("select");
			log_printf(LOG_ERR, "select() returned an error (%d). "
				     "reopening socket", err);
			close(sd);
			sd = open_listen_socket();
			continue;
		}

		log_printf(LOG_VDEBUG, "select() returned, %d fds changed", n);

		/* handle new connections on the main socket */
		if (FD_ISSET(sd, &fdset))
			new_connection(sd);

		/* check the servers list to see if data has arrived */
		i = servers.head;
		while (i) {
			/* i may be deleted in the data_from_server function,
			 * so we need to find the next item here. */
			listitem *next = i->next;
			struct server *server = (struct server *)i->item;

			if (server->sd && FD_ISSET(server->sd, &fdset))
				data_from_server(server);
			i = next;
		}

		/* process timeouts */
		if (get_timeout() <= 0)
			process_tasklist();
	}

	close(sd);
}

static void new_connection(int sd)
{
	struct sockaddr_in addr;
	int new_sd;
	unsigned int length;
	char *host;
	struct server *new_server;
	listitem *i;

	assert(sd);

	length = sizeof(addr);

	/* Accept the incoming connection */
	if ((new_sd = accept(sd, (struct sockaddr *)&addr, &length)) < 0) {
		perror("accept");
		exit(1);
	}
	host = inet_ntoa(addr.sin_addr);

	new_server = 0;

	list_for_each(servers, i) {
		struct server *server = (struct server *)(i->item);
		if (server->address.s_addr == addr.sin_addr.s_addr) {
			new_server = server;
			log_printf(LOG_DEBUG, "Found an exisiting server entry "
					"for new connection from %s, reusing",
				       host);
			break;
		}
	}

	if (!new_server) {
		if (!(new_server = malloc(sizeof(struct server)))) {
			perror("malloc");
			log_printf(LOG_WARN, "Connection from %s rejected, "
					"no connections available", host);
			close(new_sd);
			return;
		}
		list_add(&servers, new_server);
		log_printf(LOG_VDEBUG, "Servers list has %d items",
				servers.count);
		strncpy(new_server->name, host, 16);
	}

	new_server->sd = new_sd;
	new_server->address = addr.sin_addr;
	new_server->retries = 0;
	log_printf(LOG_INFO, "Connection from %s added to server table", host);

}

/**
 * Throw away whatever is currently queued on a socket, without blocking.
 *
 * Used after rejecting a packet so its trailing bytes are not parsed as the
 * next header. A blocking drain loop would stall the whole daemon once the
 * queue is empty, so only data that is already available is read.
 *
 * @param sd the socket to drain
 */
static void discard_pending(int sd)
{
#ifdef MSG_DONTWAIT
	char scratch[1024];

	while (recv(sd, scratch, sizeof(scratch), MSG_DONTWAIT) > 0)
		;
#else
	/* no per-call non-blocking flag: leave trailing bytes in place */
	(void)sd;
#endif
}

/**
 * Read and process one NECP packet from a connected server.
 *
 * First reads a header and validates its protocol identifier, version and
 * payload length. A version mismatch is answered with an NECP_F_Error
 * packet. A valid packet's payload is then read, and the whole packet
 * (header + payload) is passed to handle_necp_data(). Rejected packets
 * have any already-queued trailing data discarded. If the read returns
 * 0 or an error, the server is deactivated.
 *
 * Header and payload are each fetched with a single read(); a short read
 * is treated as malformed. NOTE(review): TCP may legitimately split a
 * packet across reads. Proper reassembly would need per-server buffering.
 *
 * @param serverp the server whose socket select() reported as readable
 */
void data_from_server(struct server *serverp)
{
	int sd;
	ssize_t rc;
	size_t packet_len;
	struct necp_header *header;
	char *buf;
	char *grown;

	assert(serverp);

	sd = serverp->sd;

	/* read header into the buffer */
	if (!(buf = malloc(sizeof(*header)))) {
		perror("malloc");
		exit(1);
	}
	rc = read(sd, buf, sizeof(*header));

	log_printf(LOG_VDEBUG, "Read %d bytes from %s",
			(int)rc, serverp->name);

	if (rc <= 0) {
		/* the socket has closed */
		deactivate_server(serverp);
		goto out;
	}

	if ((size_t)rc != sizeof(*header)) {
		log_printf(LOG_DEBUG,
		    "Malformed header from %s, ignoring packet",
		    serverp->name);
		goto out;
	}

	/* create the header struct */
	header = (struct necp_header *)buf;
	prepare_header_in(header);

	/* check protocol identifier */
	if (header->protocol_identifier != NECP_PID) {
		log_printf(LOG_DEBUG,
		    "Invalid protocol identifier from %s (%4x)",
		    serverp->name,
		    (unsigned int)header->protocol_identifier);
		discard_pending(sd);
		goto out;
	}

	/* check protocol version */
	if (header->version != NECP_VERSION) {
		log_printf(LOG_DEBUG,
		    "Invalid NECP version received from %s (%1d)",
		    serverp->name, (int)header->version);

		discard_pending(sd);

		/* send back a protocol version mismatch packet */
		header->flags  = NECP_F_Error;
		header->payload_len = 0;

		send_packet(serverp, (char *)header);
		goto out;
	}

	/* discard packets with an unreasonable length; this bound also
	 * limits how much memory an untrusted length field can request */
	if (header->payload_len > MAX_PAYLOAD_N *
			sizeof(struct necp_payload_unit)) {
		log_printf(LOG_WARN, "Packet from %s was too large "
		              "(%d payload units, max is %d)",
			      serverp->name,
			      (int)(header->payload_len
			        / sizeof(struct necp_payload_unit)),
			      (int)MAX_PAYLOAD_N
			      );
		discard_pending(sd);
		goto out;
	}

	/* header is OK: grow the buffer to hold header + payload */
	packet_len = sizeof(*header) + header->payload_len;
	if (!(grown = realloc(buf, packet_len))) {
		/* buf is still valid here and is freed below */
		perror("realloc");
		goto out;
	}
	buf = grown;

	/* need to reassign the header, because buf may have moved. */
	header = (struct necp_header *)buf;

	/* read the payload in after the header */
	rc = read(sd, buf + sizeof(*header), header->payload_len);
	if (rc < 0 || (size_t)rc != header->payload_len) {
		log_printf(LOG_DEBUG,
		    "Packet payload length (%d) does not "
		    "match length field (%d), ignoring",
		    (int)rc, (int)header->payload_len);
		goto out;
	}

	/* Payload units are passed on exactly as read. NOTE(review): the
	 * original had an empty per-unit loop here, presumably intended for
	 * byte-order conversion; confirm handle_necp_data() handles that. */
	handle_necp_data(serverp, buf, packet_len);

out:
	free(buf);
}

/**
 * Close a server's connection, if it has one, and mark it disconnected.
 * An sd of 0 is used as the "no open connection" marker.
 *
 * @param server the server to disconnect
 */
void disconnect_server(struct server *server)
{
	/* nothing open: nothing to do */
	if (!server->sd)
		return;

	close(server->sd);
	server->sd = 0;
}

int send_packet(struct server *server, char *packet)
{
	int packetsize;
	struct necp_header *header;

	assert(server);
	assert(packet);

	/* set the header pointer to the start of the packet */
	header = (struct necp_header *)packet;

	/* set global packet parameters */
	header->protocol_identifier = NECP_PID;
	header->version = NECP_VERSION;

	/* these will require changing once authentication is added */
	header->seq_num = 0;

	/* determine the number of bytes to send */
	packetsize = sizeof(*header) + header->payload_len;

	prepare_header_out(header);

	if (write(server->sd, packet, packetsize) != packetsize) {
		perror("write");
		return -1;
	}

	return 0;
}

/**
 * Signal handler: clears necp_listener_running so the main loop in
 * necp_listen() exits after its current select() returns.
 *
 * @param signum the signal caught (unused; every signal is treated alike)
 */
void interrupt(int signum)
{
	(void)signum;
	necp_listener_running = 0;
}
