/*
 * feedbackd - dynamic feedback system for LVS
 * Copyright (C) 2002 Jeremy Kerr
 * 
 * This file is part of feedbackd.
 *
 *  feedbackd is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  feedbackd is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with feedbackd; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

/**
 * @file necp_handler.c
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "necp_handler.h"
#include "necp_listener.h"
#include "necp.h"
#include "log.h"
#include "ipvs_interface.h"
#include "scheduler.h"

/* Look up the configured virtual service matching a protocol/port pair.
 * The port is compared as-is against the ipvs service table, so callers
 * pass it in network byte order (handle_necp_start converts it with
 * htons before calling). Defined further down in this file. */
static struct virtual_service *find_virtual_service(uint8_t protocol,
		uint16_t port);

void handle_necp_data(struct server *server, char *packet, int size)
{
	struct necp_header *header;

	assert(server);
	assert(packet);
	assert(size >= NECP_PACKET_SIZE(0));

	/* set the header to the start address of the packet */
	header = (struct necp_header *) packet;

	switch (header->opcode) {
	case NECP_INIT:
		handle_necp_init(server, packet);
		break;

	case NECP_START:
		handle_necp_start(server, packet);
		break;
		
	case NECP_STOP:
		handle_necp_stop(server, packet);
		break;

	case NECP_KEEPALIVE_ACK:
		handle_necp_keepalive_ack(server, packet);
		break;

	case NECP_NOOP:
		break;

	default:
		log_printf(LOG_DEBUG, "Got unknown packet (opcode: %d) from %s",
		                header->opcode, server->name);
	}
		
}	

/**
 * Handle an NECP_INIT packet and reply with an NECP_INIT_ACK.
 *
 * An NECP_INIT must carry exactly one payload unit. If bit 0 of that
 * unit's data0 is set, the client is asking for authentication, which this
 * version does not support. The error is logged, but the ACK is still sent
 * with the payload echoed unchanged. Otherwise data0..data2 are cleared.
 * The reply is built in place in @packet.
 *
 * @param server the server the packet came from
 * @param packet the received packet; rewritten in place as the ACK
 * @return 0 if an ACK was sent, -1 if the payload had the wrong size
 */
int handle_necp_init(struct server *server, char *packet)
{
	struct necp_header *header;
	struct necp_payload_unit *payload;

	assert(server);
	assert(packet);

	header = (struct necp_header *)packet;

	log_printf(LOG_DEBUG, "Got NECP_INIT from %s", server->name);

	/* NECP_INIT always has one payload unit */
	if (header->payload_len != sizeof(*payload)) {
		/* the format string needs the server name; it was missing */
		log_printf(LOG_WARN, "NECP_INIT from %s was incorrect size, "
				"ignoring", server->name);
		/* send error? */
		return -1;
	}

	payload = (struct necp_payload_unit *)(packet + sizeof(*header));

	if (ntohl(payload->data0) & 0x01) {
		/* server wants authentication. set flags in
		 * server entry table */
		log_printf(LOG_ERR, "Client wants authentication, but this is "
				"not supported in this version of "
				"feedbackd-master");
	} else {
		payload->data0 = payload->data1 = payload->data2 = 0;
	}

	header->opcode = NECP_INIT_ACK;

	send_packet(server, packet);

	return 0;
}

int handle_necp_start(struct server *server, char *packet)
{
	struct necp_header *header;
	struct necp_payload_unit *payload, *last_payload;

	assert(packet);

	header = (struct necp_header *)packet;

	assert(header->opcode == NECP_START);

	log_printf(LOG_DEBUG, "Got NECP_START from %s", server->name);

	payload = (struct necp_payload_unit *)(packet + sizeof(*header));

	while (((int)payload - (int)packet) < header->payload_len) {
		uint8_t protocol;
		uint16_t port, forwarding;
		struct service *service;
		
		forwarding = ntohl(payload->data0) & 0xFFFF;
		protocol   = ntohl(payload->data1) & 0xFF;
		port       = htons(htonl(payload->data2) & 0xFFFF);

		service = find_service(server, protocol, port);

		if (service) {
			if (service->active) {
				log_printf(LOG_DEBUG,
				    "Received NECP_START for active "
				    "service. Ignoring.");
				return 0;
			} else {
				log_printf(LOG_DEBUG,
				    "Received NECP_START for quiescing "
				    "service. Reactivating.");
				service->active = 1;
				cancel_removal(service);
				payload++;
				continue;
			}
		}

		/* the service doesn't exist, create a new one */
		if (!(service = malloc(sizeof(struct service)))) {
			perror("malloc");
			return -1;
		}

		service->virtual_service =
			find_virtual_service(protocol, port);

		if (!service->virtual_service) {
			/* errors are marked by setting the error flag in
			 * the packet data, since we'll need to do this
			 * anyway, and the erroneous payload units are marked by
			 * a 1 in the data3 field */
			header->flags |= NECP_F_Error;
			payload->data3 = 1;
			log_printf(LOG_WARN,
			    "No virtual service found for server: %s "
			    "protocol: %d port: %d forwarding: %d. Rejecting.",
			    server->name, protocol, ntohs(port),
			    forwarding);
			payload++;
			continue;
		}

		/* establish the service object */
		service->server         = server;
		service->forwarding     = forwarding;
		service->port           = port;
		service->current_weight = 0;
		service->active         = 1;

		/* add the service to the ipvs tables */
		log_printf(LOG_INFO,
		    "Notifed of new service (server %s, port %d, protocol %d, "
		    "forwarding %d). Adding to ipvs tables.",
		    server->name, ntohs(port), protocol,
		    forwarding);
		ipvsif_add_service(service);

		/* add this to the list */
		list_add(&services, service);
		log_printf(LOG_VDEBUG, "Services list now has %d items",
	                services.count);

		/* make sure this payload doesn't get included in
		 * an error */
		payload->data3 = 0;

		payload++;
	}

	/* prepare an ACK */
	payload = last_payload = (struct necp_payload_unit *)
		(packet + sizeof(*header));

	if (header->flags & NECP_F_Error) {
		/* if there was an error, look through all the payload units
		 * to see which one(s) caused it. these are moved to the front
		 * of the payload array */
		while (((int)payload - (int)packet) < header->payload_len) {
			if (payload->data3) {
				if (payload != last_payload) {
					memcpy(last_payload, payload,
					       sizeof(*payload));
					last_payload->data3 = 0;
				}
				last_payload++;
			}
			payload++;
		}
	}

	/* schedule a keepalive if there was at least one payload that wasn't
	 * an error */
	if (header->payload_len != (int)last_payload -
	    (int)packet - sizeof(*header)) {
		schedule_keepalive(server);
	}

	/* return the ACK */
	header->opcode = NECP_START_ACK;
	header->payload_len = (int)last_payload -
	                      (int)packet - sizeof(*header);
	send_packet(server, packet);

	return 0;
}

/**
 * Handle an NECP_STOP packet: quiesce each service it names, then reply
 * with an NECP_STOP_ACK.
 *
 * For each whole payload unit, the protocol (data1) and port (data2)
 * identify one of @server's services. An active service is handled like
 * this:
 * - it is marked inactive and its ipvs weight is set to 0;
 * - if conditional_service_removal() reports live connections, it is
 *   scheduled for later removal;
 * - otherwise its entry is deleted immediately.
 *
 * The ACK is built in place in @packet. If NECP_F_Error is set, the units
 * flagged in data3 are moved to the front and only those are returned.
 *
 * @param server the server the packet came from
 * @param packet the received packet; rewritten in place as the ACK
 * @return 0. An unknown or already-quiescing service ends processing
 *         early, and then no ACK is sent.
 */
int handle_necp_stop(struct server *server, char *packet)
{
	struct necp_header *header;
	struct necp_payload_unit *units;
	size_t nunits, nerrors, k;

	log_printf(LOG_DEBUG, "Got NECP_STOP from %s", server->name);

	header = (struct necp_header *)packet;

	/* payload_len counts the payload bytes after the header; iterate
	 * over whole units by index instead of casting pointers to int */
	units  = (struct necp_payload_unit *)(packet + sizeof(*header));
	nunits = header->payload_len / sizeof(*units);

	for (k = 0; k < nunits; k++) {
		struct necp_payload_unit *payload = &units[k];
		struct service *service;
		uint8_t protocol;
		uint16_t port;

		protocol = ntohl(payload->data1) & 0xFF;
		port     = htons(ntohl(payload->data2) & 0xFFFF);

		service = find_service(server, protocol, port);

		if (!service) {
			/* ignore for now, but should this cause an error? */
			log_printf(LOG_DEBUG,
			    "Received NECP_STOP for unknown service. "
			    "Ignoring");
			return 0;
		}

		if (!service->active) {
			log_printf(LOG_DEBUG,
			    "Received NECP_STOP for quiescing service. "
			    "Ignoring.");
			return 0;
		}

		log_printf(LOG_INFO,
		    "Notified of quiescing service (server %s, "
		    "protocol %d, port %d). "
		    "Disabling service. ",
		    server->name, protocol, ntohs(port));

		/* stop allocations to this server */
		service->active = 0;
		service->current_weight = 0;
		ipvsif_set_weight(service);

		/* remove now if idle, otherwise schedule the removal from
		 * the ipvs tables */
		if (conditional_service_removal(service)) {
			schedule_removal(service);
		} else {
			delete_service_entry(service);
		}

		/* make sure this payload doesn't get included in
		 * an error */
		payload->data3 = 0;
	}

	/* prepare ACK. nothing above sets NECP_F_Error, but the client may
	 * have; if so, move the flagged units to the front */
	nerrors = 0;
	if (header->flags & NECP_F_Error) {
		for (k = 0; k < nunits; k++) {
			if (!units[k].data3)
				continue;
			if (k != nerrors) {
				memcpy(&units[nerrors], &units[k],
				       sizeof(units[k]));
				units[nerrors].data3 = 0;
			}
			nerrors++;
		}
	}

	/* return the ACK */
	header->opcode = NECP_STOP_ACK;
	header->payload_len = nerrors * sizeof(*units);
	send_packet(server, packet);

	return 0;
}

int handle_necp_keepalive_ack(struct server *server, char *packet)
{
	struct necp_header *header;
	struct necp_payload_unit *payload;

	log_printf(LOG_DEBUG, "Got NECP_KEEPALIVE_ACK from %s", server->name);

	header = (struct necp_header *)packet;

	payload = (struct necp_payload_unit *)(packet + sizeof(*header));

	if (header->flags & NECP_F_Error) {
		/* handle keepalive errors how? */
	} else if (header->request_id != server->request_id) {
		log_printf(LOG_INFO, "KEEPALIVE ACK from %s had an invalid "
				"request id, ignoring", server->name);
	} else {
		uint8_t protocol;
		uint32_t port;
		struct service *service;

		payload = (struct necp_payload_unit *)
		          (packet + sizeof(*header));

		/* for each payload unit */
		while (((int)payload - (int)packet) < header->payload_len) {

			/* get protocol and port data from packet */
			protocol = ntohl(payload->data1) & 0xFF;
			port = htons(ntohl(payload->data2) & 0xFFFF);

			service = find_service(server, protocol, port);

			if (!service) {
				log_printf(LOG_WARN,
				    "Keepalive ACK received for unknown "
				    "service. Ignoring");
			} else {
				service->current_weight = ntohl(payload->data3)
					                  & 0xFF;
				log_printf(LOG_INFO,
				    "Service (server %s, protocol %d, "
				    "port %d) reported health of %d.",
				    server->name, protocol,
				    ntohs(port), service->current_weight);

				ipvsif_set_weight(service);
			}
			payload++;
		}
		server->retries = 0;
	}
	/* reset keepalive timeout */
	cancel_keepalive_timeout(server);

	/* schedule next keepalive */
	schedule_keepalive(server);

	return 0;
}

/**
 * Find the service entry owned by @server for a protocol/port pair.
 *
 * @param server   owning server (compared by pointer)
 * @param protocol protocol of the service's virtual service
 * @param port     port as stored in the entry (network byte order, as
 *                 produced by the NECP handlers)
 * @return the first matching entry in the global services list, or NULL
 */
struct service *find_service(struct server *server, uint8_t protocol,
		uint16_t port)
{
	listitem *node;

	list_for_each(services, node) {
		struct service *candidate = (struct service *)node->item;
		int same_owner = (candidate->server == server);

		if (same_owner
		    && candidate->virtual_service->ipvs_service.protocol
		           == protocol
		    && candidate->port == port)
			return candidate;
	}

	return NULL;
}

/*
 * Return the first configured virtual service whose ipvs protocol and port
 * both equal the arguments, or NULL if there is none. The port is compared
 * as stored in the ipvs service table.
 */
static struct virtual_service *find_virtual_service(uint8_t protocol,
		uint16_t port)
{
	listitem *node;

	list_for_each(virtual_services, node) {
		struct virtual_service *vs =
			(struct virtual_service *)node->item;
		int proto_match = (vs->ipvs_service.protocol == protocol);

		if (proto_match && vs->ipvs_service.port == port)
			return vs;
	}

	return NULL;
}

/**
 * Send an NECP_KEEPALIVE to @server covering every active service it owns.
 *
 * Each active service gets one payload unit:
 *   data0  htonl(1)
 *   data1  protocol of the service's virtual service
 *   data2  service port, converted to host order and then to a 32-bit
 *          network-order value
 *   data3  zero
 *
 * A fresh rand() request id is stored in both the header and
 * server->request_id, so the matching KEEPALIVE_ACK can be recognised.
 * Nothing is sent if the server has no connected socket
 * (server->sd <= 0) or if the packet cannot be allocated.
 */
void send_keepalive(struct server *server)
{
	listitem *i;
	struct service *service;
	struct necp_header *header;
	struct necp_payload_unit *payload;
	char *packet;
	size_t packet_size;
	int nservices;

	assert(server);

	/* if the server is not connected, don't try and send a keepalive */
	if (!(server->sd > 0))
		return;

	/* count the services first so the packet is allocated once. Growing
	 * it with realloc() per service was quadratic, and it leaked the
	 * buffer when realloc() failed. */
	nservices = 0;
	list_for_each(services, i) {
		service = (struct service *)i->item;
		if (service->server == server && service->active)
			nservices++;
	}

	packet_size = sizeof(*header) + (size_t)nservices * sizeof(*payload);

	if (!(packet = malloc(packet_size))) {
		perror("malloc");
		return;
	}
	/* zero everything, including header fields not set below */
	memset(packet, 0, packet_size);

	header = (struct necp_header *)packet;

	/* basic parameters for the keepalive packet */
	header->flags = NECP_F_Basic_Payload;
	header->opcode = NECP_KEEPALIVE;

	/* one payload unit per active service of this server, in list order */
	payload = (struct necp_payload_unit *)(packet + sizeof(*header));
	list_for_each(services, i) {
		service = (struct service *)i->item;
		if (service->server == server && service->active) {
			payload->data0 = htonl(0x01);
			payload->data1 = htonl((uint32_t)service->
					virtual_service->ipvs_service.protocol);
			/* port is stored 16-bit network order; resend it as
			 * a 32-bit network-order value */
			payload->data2 = htonl((uint32_t)ntohs(service->port));
			payload++;
		}
	}

	/* set final header params */
	header->payload_len = nservices * sizeof(struct necp_payload_unit);
	header->request_id = server->request_id = rand();

	log_printf(LOG_DEBUG, "Sending NECP_KEEPALIVE to %s",
	               server->name);

	/* whee! */
	send_packet(server, packet);

	free(packet);
}

int deactivate_server(struct server *server)
{
	listitem *i;
	int nservices = 0;

	log_printf(LOG_DEBUG, "Attempting server removal for server %s.",
	               server->name);

	/* don't send any keepalives or keepalive timeouts */
	unschedule_server(server);

	/* close the connection if necessary */
	disconnect_server(server);

	/* check to see if there are services for this server. If there are,
	 * they will either be active or inactive. Active services are
	 * deactivated, then both types of service are scheduled for removal.
	 * The scheduler is then responsible for removing services from the
	 * the main list */
	i = services.head;

	while (i) {
		listitem *next = i->next;
		struct service *service = (struct service *)i->item;
		if (service->server == server) {
			service->current_weight = 0;
			ipvsif_set_weight(service);
			if (conditional_service_removal(service) > 0) {
				schedule_removal(service);
				nservices++;
			} else {
				delete_service_entry(service);
			}
		}
		i = next;
	}

	if (!nservices) {
		log_printf(LOG_INFO, "Removing server %s from servers list",
		              server->name);
		list_remove(&servers, server);
		log_printf(LOG_VDEBUG, "Servers list now has %d items",
		                 servers.count);
		free(server);
	} else {
		log_printf(LOG_DEBUG, "Server %s still has %d related "
				"services, not removing.", server->name,
			       nservices);
	}

	return nservices;
}

void delete_service_entry(struct service *service)
{
	log_printf(LOG_VDEBUG, "Deleting service entry (server: %s, port %d)",
			service->server->name, ntohs(service->port));
	list_remove(&services, service);
	log_printf(LOG_VDEBUG, "Services list now has %d items",
			services.count);
	free(service);
}

/**
 * Mark @service inactive, and remove it from ipvs if it has no
 * connections.
 *
 * @return the connection count reported by ipvsif_get_service_conns()
 *         when it is positive (the service stays in ipvs). Otherwise the
 *         service is removed from ipvs and 0 is returned. The service
 *         entry itself is never freed here.
 */
int conditional_service_removal(struct service *service)
{
	int conns;

	service->active = 0;

	conns = ipvsif_get_service_conns(service);
	if (conns > 0)
		return conns;

	log_printf(LOG_INFO,
	    "Service (server %s, protocol %d, port "
	    "%d) has 0 active connections, removing.",
	    service->server->name,
	    service->virtual_service->ipvs_service.protocol,
	    ntohs(service->port));
	ipvsif_remove_service(service);

	return 0;
}

int conditional_server_removal(struct server *server)
{
	int nservices = 0;
	listitem *i;

	list_for_each(services, i) {
		struct service *service = (struct service *)(i->item);
		if (service->server == server)
			nservices++;
	}

	if (!nservices) {
		log_printf(LOG_INFO, "Server %s has no active services, "
				"removing", server->name);
		/* just to be safe - make sure there are no tasks pending */
		unschedule_server(server);
		list_remove(&servers, server);
		log_printf(LOG_VDEBUG, "Servers list now has %d items",
				servers.count);
		free(server);
	}

	return nservices;
}
