#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <syslog.h>
#include <errno.h>
#include <string.h>
#include <signal.h>
#include <rpc/rpc.h>
#include <sys/queue.h>
#include <rpcsvc/nlm_prot.h>
#include <rpcsvc/sm_inter.h>
#include "lock_common.h"

/*
 * List of hosts this process currently has monitored through the local
 * statd.  Entries are linked in by do_mon() after a successful SM_MON and
 * unlinked by do_unmon() (after SM_UNMON) or by common_notify() (on a
 * state change).  struct host is only referenced by pointer here, so it
 * may be defined below.
 */
LIST_HEAD(hostlst_head, host);
struct hostlst_head hostlst_head = LIST_HEAD_INITIALIZER(hostlst_head);

/*
 * One monitored host on hostlst_head.  Allocated in do_mon() and freed in
 * do_unmon() or common_notify().
 */
struct host {
	LIST_ENTRY(host) hostlst;	/* linkage on hostlst_head */
	char name[SM_MAXSTRLEN+1];	/* host name, copied with strlcpy() and
					 * matched with strcmp() on lookup */
	int state;			/* stat_res.state from the SM_MON reply;
					 * compared with the state passed to
					 * common_notify().  NOTE(review): confirm
					 * whether SM_MON reports the peer's or the
					 * local statd's state number. */
	int refcnt;			/* do_mon() calls not yet balanced by
					 * do_unmon(); unmonitored at zero */
};

int sm_state;		/* local NSM state number, taken from the SM_UNMON_ALL
			 * reply in common_init() */
CLIENT *sm_client;	/* RPC client handle for the local statd, created in
			 * common_init() */
struct mon sm_mon;	/* shared request template for SM_MON/SM_UNMON/
			 * SM_UNMON_ALL calls; mon_name is set only around
			 * each individual call */

/*
 * Add SIGCHLD to the calling process's blocked-signal mask.
 *
 * A sigprocmask() failure is logged at LOG_WARNING and otherwise ignored.
 * Calls do not nest: sigunlock() removes SIGCHLD from the mask
 * unconditionally, no matter how many times siglock() ran.
 */
void
siglock()
{
	sigset_t chld_mask;

	sigemptyset(&chld_mask);
	sigaddset(&chld_mask, SIGCHLD);

	if (sigprocmask(SIG_BLOCK, &chld_mask, NULL) >= 0)
		return;

	syslog(LOG_WARNING, "siglock failed: %s", strerror(errno));
}

/*
 * Remove SIGCHLD from the calling process's blocked-signal mask
 * (the counterpart of siglock()).
 *
 * A sigprocmask() failure is logged at LOG_WARNING and otherwise ignored.
 */
void
sigunlock()
{
	sigset_t chld_mask;
	int failed;

	sigemptyset(&chld_mask);
	sigaddset(&chld_mask, SIGCHLD);

	failed = sigprocmask(SIG_UNBLOCK, &chld_mask, NULL) < 0;
	if (failed)
		syslog(LOG_WARNING, "sigunlock failed: %s", strerror(errno));
}

/* monitor a host through rpc.statd, and keep a ref count */
void
do_mon(hostname)
	char *hostname;
{
	struct host *hp;
	struct sm_stat_res stat_res;
	struct timeval timeo = { 3, 0 };
	enum clnt_stat clnt_stat;

	LIST_FOREACH(hp, &hostlst_head, hostlst) {
		if (strcmp(hostname, hp->name) == 0) {
			/* already monitored, just bump refcnt */
			hp->refcnt++;
			return;
		}
	}
	/* not found, have to create an entry for it */
	hp = malloc(sizeof(struct host));
 	if (hp == NULL) {
 		syslog(LOG_WARNING, "can't monitor host %s: malloc failed", hostname);
 		return;
	}
	strlcpy(hp->name, hostname, sizeof(hp->name));
	hp->refcnt = 1;
	syslog(LOG_DEBUG, "monitoring host %s", hostname);
	sm_mon.mon_id.mon_name = hostname;
	clnt_stat = clnt_call(sm_client, SM_MON, xdr_mon, (char*)&sm_mon, xdr_sm_stat_res, (char*)&stat_res, timeo);
	sm_mon.mon_id.mon_name = NULL;
	if (clnt_stat != RPC_SUCCESS) {
		syslog(LOG_ERR, clnt_sperror(sm_client, "do_mon: rpc failed"));
		free(hp);
	} else if (stat_res.res_stat == stat_fail) {
		syslog(LOG_ERR, "do_mon: statd failed");
		free(hp);
	} else {
		hp->state = stat_res.state;
		LIST_INSERT_HEAD(&hostlst_head, hp, hostlst);
	}
}

void
do_unmon(hostname)
	char *hostname;
{
	struct host *hp;
	struct sm_stat stat;
	struct timeval timeo = { 3, 0 };
	enum clnt_stat clnt_stat;

	LIST_FOREACH(hp, &hostlst_head, hostlst) {
		if (strcmp(hostname, hp->name) == 0) {
			hp->refcnt--;
			if (hp->refcnt == 0) goto unmon;
			return;
		}
	}
 	syslog(LOG_WARNING, "can't unmonitor host %s: not monitoring",
 		hostname);
 	return;
unmon:	syslog(LOG_DEBUG, "unmonitoring host %s", hostname);
	sm_mon.mon_id.mon_name = hostname;
	clnt_stat = clnt_call(sm_client, SM_UNMON, xdr_mon_id, (char*)&sm_mon.mon_id, xdr_sm_stat, (char*)&stat, timeo);
	sm_mon.mon_id.mon_name = NULL;
	if (clnt_stat != RPC_SUCCESS) {
		syslog(LOG_ERR, clnt_sperror(sm_client, "do_unmon: rpc failed"));
	} else {
		LIST_REMOVE(hp, hostlst);
		free(hp);
	}
}

void
common_notify(hostname, state)
	char *hostname;
	int state;
{
	struct host *hp;
	
	syslog(LOG_DEBUG, "notify from %s, new state %d", hostname, state);
	for (hp = LIST_FIRST(&hostlst_head); hp != NULL; hp = LIST_NEXT(hp, hostlst)) {
		if (strcmp(hostname, hp->name) == 0) {
			if (state != hp->state) {
			 	syslog(LOG_DEBUG, "old state %d, removing", hp->state);
			 	LIST_REMOVE(hp, hostlst);
			 	free(hp);
			} else {
			 	syslog(LOG_INFO, "received notification from host %s, but state is still %d", hostname, state);
			}
			return;
		}
	}
	syslog(LOG_NOTICE, "received notification from host %s, state %d, but not on monitor list", hostname, state);
}

/*
 * Signal hook taking a signal number; intentionally a no-op in this
 * implementation.  NOTE(review): presumably invoked on SIGCHLD (going by
 * the name); confirm at the call site.
 */
void common_sigchild(int s)
{
	(void)s;	/* intentionally unused */
}

/*
 * Hook with nothing to do in this implementation.  NOTE(review): the name
 * suggests a periodic-poll callback; confirm at the call site.
 */
void common_poll(void)
{
	return;
}

void common_init(void)
{
	struct sm_stat sm_stat;
	enum clnt_stat clnt_stat;
	struct timeval timeo = { 10, 0 };
	
	syslog(LOG_DEBUG, "creating client handle for local NSM");
	for (;;) {
		sm_client = clnt_create("localhost", SM_PROG, SM_VERS, "udp");
		if (sm_client != NULL)
			break;
		if (rpc_createerr.cf_stat == RPC_PROGNOTREGISTERED) {
			syslog(LOG_INFO, "waiting for statd to register");
			sleep(1);
			continue;
		}
		syslog(LOG_ERR, clnt_spcreateerror("clnt_create() for local NSM failed"));
		exit(1);
	}

	sm_mon.mon_id.my_id.my_name = "localhost";
	sm_mon.mon_id.my_id.my_prog = 0;
	sm_mon.mon_id.my_id.my_vers = 0;
	sm_mon.mon_id.my_id.my_proc = 0;

	for (;;) {
		syslog(LOG_DEBUG, "calling SM_UNMON_ALL");
		clnt_stat = clnt_call(sm_client, SM_UNMON_ALL, xdr_my_id, (char*)&sm_mon.mon_id.my_id, xdr_sm_stat, (char*)&sm_stat, timeo);
		if (clnt_stat == 0)
			break;
		/* XXX can this really happen? */
		if (clnt_stat == RPC_PROGUNAVAIL) {
			syslog(LOG_INFO, "waiting for statd to come up");
			sleep(1);
			continue;
		} 
		syslog(LOG_ERR, clnt_sperror(sm_client, "clnt_call() for local NSM failed"));
		exit(1);
	}
	sm_state = sm_stat.state;
	syslog(LOG_DEBUG, "local NSM state %d", sm_state);
	
	/* for future calls to SM_LOCK/SM_UNLOCK */
	memset(&sm_mon.priv, 0, sizeof(&sm_mon.priv));
	sm_mon.mon_id.mon_name = NULL;
	sm_mon.mon_id.my_id.my_prog = NLM_PROG;
	sm_mon.mon_id.my_id.my_vers = NLM_SM;
	sm_mon.mon_id.my_id.my_proc = NLM_SM_NOTIFY;
}
