/* $Id: authclnt.C,v 1.57 2002/12/01 02:45:25 dm Exp $ */

/*
 *
 * Copyright (C) 2002 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "sfscrypt.h"
#include "sfsschnorr.h"
#include "sfsauthd.h"
#include "rxx.h"
#include "sfskeymisc.h"

#if HAVE_GETSPNAM
#include <shadow.h>
#endif /* HAVE_GETSPNAM */

/* crypt(3) is declared by hand here; it is used by the unix password
   login path (authclnt::unixpwauth) below. */
extern "C" char *crypt (const char *, const char *);
/* Global table of server private key halves, indexed by the digest
   returned from hash_sprivk ().  Entries are bound/released by
   authclnt::urec_t and invalidated by authclnt::update_srv_keyhalf. */
sprivk_tab_t sprivk_tab;

/* Hash a server private key half with sha1_hashxdr and return the raw
   digest bytes as a str.  Returns a NULL str if hashing fails. */
static str
hash_sprivk (const sfs_2schnorr_priv_xdr &k)
{
  sfs_hash digest;
  if (sha1_hashxdr (&digest, k))
    return str (digest.base (), digest.size ());
  return NULL;
}

bool
sprivk_tab_t::is_valid (const str &hv)
{
  assert (hv);
  bool ret;
  sprivk_t *s = keys[hv];
  if (!s)
    ret = false;
  else 
    ret = s->valid;
  return ret;
}

bool
sprivk_tab_t::invalidate (const str &hv)
{
  assert (hv);
  sprivk_t *s = keys[hv];
  if (!s)
    return false;
  s->valid = false;
  release (hv, s);
  return true;
}

void
sprivk_tab_t::bind (const str &hv)
{
  
  assert (hv);
  sprivk_t *s = keys[hv];
  if (s)
    s->refs++;
  else {
    nentries ++;
    keys.insert (hv);
  }
}

void
sprivk_tab_t::release (const str &hv, sprivk_t *s)
{
  assert (hv);
  if (!s)
    s = keys[hv];
  if (s && --s->refs == 0) {
    nentries --;
    keys.remove (hv);
  }
}

/* Return true iff SHELL exactly matches one of the entries produced by
   getusershell ().  The shell enumeration is always closed again with
   endusershell () before returning. */
bool
validshell (const char *shell)
{
  bool found = false;

  setusershell ();
  for (const char *cand = getusershell (); cand; cand = getusershell ())
    if (strcmp (cand, shell) == 0) {
      found = true;
      break;
    }
  endusershell ();
  return found;
}

/* Unlink the user record U from both the authno table (utab) and the
   record list (ulist), then destroy it.  The urec_t destructor releases
   any key-half references the record holds in sprivk_tab. */
void
authclnt::urecfree (urec_t *u)
{
  utab.remove (u);
  ulist.remove (u);
  delete u;
}

/* Release the sprivk_tab reference held for every server private key
   half stored in this record (the constructor binds the same set). */
authclnt::urec_t::~urec_t ()
{
  if (kh.type != SFSAUTH_KEYHALF_PRIV)
    return;
  for (u_int idx = 0; idx < kh.priv->size (); idx++)
    sprivk_tab.release (hash_sprivk ((*kh.priv)[idx]));
}

/* Build the per-login record for authno A, authenticated via method T.
   When DBR is a user record, remember the user's name and server key
   half and bind every private key half in sprivk_tab. */
authclnt::urec_t::urec_t (u_int32_t a, sfs_authtype t,
                          const sfsauth_dbrec &dbr)
  : authno (a), authtype (t)
{
  if (dbr.type != SFSAUTH_USER)
    return;

  uname = dbr.userinfo->name;
  kh = dbr.userinfo->srvprivkey;
  if (kh.type != SFSAUTH_KEYHALF_PRIV)
    return;

  for (u_int idx = 0; idx < kh.priv->size (); idx++)
    sprivk_tab.bind (hash_sprivk ((*kh.priv)[idx]));
}

/* Create an authentication client on transport X.  AUP, when non-NULL,
   carries the unix credentials of a local peer: its uid is recorded,
   SRP login through sfsauth_login is enabled only for uid 0, and a
   "LOCAL..." client name is synthesized if none is set yet. */
authclnt::authclnt (ref<axprt_crypt> x, const authunix_parms *aup)
  : sfsserv (x),
    authsrv (asrv::alloc (x, sfsauth_prog_2,
                          wrap (this, &authclnt::dispatch))),
    sfsauth_login_srp (aup && !aup->aup_uid)
{
  if (!aup)
    return;

  uid.alloc ();
  *uid = aup->aup_uid;

  /* Keep a client name that has already been set. */
  if (client_name && client_name.len ())
    return;

  if (*uid)
    client_name = strbuf ("LOCAL(uid=%d)", *uid);
  else
    client_name = "LOCAL";
}

/* Free every user record still on ulist; urecfree unlinks each record
   from utab and ulist and deletes it. */
authclnt::~authclnt ()
{
  ulist.traverse (wrap (this, &authclnt::urecfree));
}

/* Connection hook: copy this server's servinfo (myservinfo) into *SI
   and return the server's private key (myprivkey).  The connect
   argument CI is not examined. */
ptr<sfspriv>
authclnt::doconnect (const sfs_connectarg *ci,
		     sfs_servinfo *si)
{
  *si = myservinfo;
  return myprivkey;
}

/* Return the externally visible form of user NAME as stored in
   database DBP: "<prefix>/<name>" when the database has a prefix,
   otherwise NAME unchanged.

   The prefixed string used to be built and then thrown away (the
   function returned NAME unconditionally).  Callers such as findgroups
   and sfsauth_update compare against prefixed names, so the prefixed
   result is what must be returned. */
inline str
mkname (const dbfile *dbp, str name)
{
  str r;
  if (dbp->prefix)
    r = dbp->prefix << "/" << name;
  else
    r = name;
  return r;
}


/* Fill *RESP with a successful login result carrying one unix
   credential for the user record AE found in database DBP.

   The uid and gid are translated through the database's id maps when
   it has them; as in get_user_cursor, a database without a map passes
   ids through unchanged instead of dereferencing a null map.  Home
   directory and shell default to "/dev/null" and are taken from the
   local passwd entry only when the database allows unix passwords.
   Supplementary groups come from findgroups.

   Returns false, with *RESP set to SFSLOGIN_BAD, when the uid cannot
   be mapped. */
bool
authclnt::setuser (sfsauth2_loginres *resp, const sfsauth_dbrec &ae,
                   const dbfile *dbp)
{
  assert (ae.type == SFSAUTH_USER);

  resp->set_status (SFSLOGIN_OK);
  resp->resok->creds.setsize (1);
  resp->resok->creds[0].set_type (SFS_UNIXCRED);

  str name = mkname (dbp, ae.userinfo->name);

  resp->resok->creds[0].unixcred->uid
    = dbp->uidmap ? dbp->uidmap->map (ae.userinfo->id) : ae.userinfo->id;
  if (resp->resok->creds[0].unixcred->uid == badid) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "uid out of range";
    return false;
  }

  resp->resok->creds[0].unixcred->username = name;
  resp->resok->creds[0].unixcred->homedir = "/dev/null";
  resp->resok->creds[0].unixcred->shell = "/dev/null";
  /* The length comparison rejects names with an embedded NUL byte. */
  if (dbp->allow_unix_pwd && name.len () == strlen (name)) {
    if (struct passwd *pw = getpwnam (ae.userinfo->name)) {
      resp->resok->creds[0].unixcred->homedir = pw->pw_dir;
      resp->resok->creds[0].unixcred->shell = pw->pw_shell;
    }
  }
#if 0
  /* XXX - what is this -dm?  Seems like a huge security hole. */
  else if (dbp->allow_userdir_shell) {
    // Allow remote users without a local account to specify a 
    // shell and homedir
    resp->resok->creds[0].unixcred->homedir = USERDIR_HOMEDIR;
    resp->resok->creds[0].unixcred->shell = USERDIR_SHELL;
  }
#endif

  resp->resok->creds[0].unixcred->gid
    = dbp->gidmap ? dbp->gidmap->map (ae.userinfo->gid) : ae.userinfo->gid;

  vec<u_int32_t> groups;
  findgroups (&groups, name);
  resp->resok->creds[0].unixcred->groups.setsize (groups.size ());
  /* Skip the copy for an empty list so memcpy never sees a null base. */
  if (!groups.empty ())
    memcpy (resp->resok->creds[0].unixcred->groups.base (),
            groups.base (), groups.size () * sizeof (groups[0]));

  return true;
}

/* Append a public-key credential (SFS_PKCRED) holding the armor32
   encoding of VRFY's public key hash to the credentials already in
   RESP->resok.  If the hash cannot be computed, a warning is logged
   and RESP is left untouched. */
void
setuser_pkhash (sfsauth2_loginres *resp, ptr<sfspub> vrfy)
{
  str h = vrfy->get_pubkey_hash ();
  if (!h) {
    warn << "Error in sha1_hashxdr of user's public key\n";
    return;
  }

  /* Save the current credentials, add the new one, then write the
     whole list back after growing the result vector. */
  const size_t nold = resp->resok->creds.size ();
  sfsauth_cred *old = resp->resok->creds.base ();
  vec<sfsauth_cred> saved;
  for (size_t k = 0; k < nold; k++)
    saved.push_back (old[k]);

  saved.push_back ();
  saved[nold].set_type (SFS_PKCRED);
  *saved[nold].pkhash = armor32 (h);

  const size_t nnew = nold + 1;
  resp->resok->creds.setsize (nnew);
  for (size_t k = 0; k < nnew; k++)
    resp->resok->creds[k] = saved[k];
}

/* Collect into *GROUPS the mapped group ids that user NAME belongs to
   in every configured database.

   NAME is matched against each database's prefix: a database without a
   prefix is queried with NAME as is, a database with prefix P is
   queried only when NAME has the form "P/<suffix>".  Databases whose
   cursor cannot be opened are skipped, as in the other lookup routines
   in this file.  Group ids that map to badid are dropped; a database
   without a gid map contributes its ids unchanged. */
void
authclnt::findgroups (vec<u_int32_t> *groups, str name)
{
  groups->clear ();
  for (dbfile *dbp = dbfiles.base (); dbp < dbfiles.lim (); dbp++) {
    str suffix;
    if (!dbp->prefix)
      suffix = name;
    else if (dbp->prefix.len () + 1 < name.len ()
             && name[dbp->prefix.len ()] == '/'
             && !memcmp (dbp->prefix.cstr (), name.cstr (),
                         dbp->prefix.len ()))
      suffix = substr (name, dbp->prefix.len () + 1);
    else
      continue;

    ptr<authcursor> ac = dbp->db->open ();
    if (!ac)
      continue;

    vec<u_int32_t> gv;
    ac->getgroups (&gv, suffix);
    while (!gv.empty ()) {
      u_int32_t raw = gv.pop_front ();
      u_int32_t gid = dbp->gidmap ? dbp->gidmap->map (raw) : raw;
      if (gid != badid)
        groups->push_back (gid);
    }
  }
}


/* Validate the signed request REQP against the login argument LAP.
   The request must be one of the two signed authreq types and must
   carry the same authid and sequence number as LAP.  On any mismatch
   *RESP is set to SFSLOGIN_BAD with an explanatory message and false
   is returned; otherwise *RESP is left untouched and true is
   returned. */
bool
authclnt::checkreq (sfsauth2_loginres *resp, const sfsauth2_loginarg *lap,
                    const sfs_authreq2 *reqp)
{
  if (reqp->type != SFS_SIGNED_AUTHREQ
      && reqp->type != SFS_SIGNED_AUTHREQ_NOCRED) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = strbuf ("bad sfs_authreq2 type %d\n", reqp->type);
    return false;
  }
  if (reqp->authid != lap->authid) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "authid mismatch\n";
    return false;
  }
  if (reqp->seqno != lap->arg.seqno) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "sequence number mismatch\n";
    return false;
  }
  return true;
}

/* Disabled code (compiled out by #if 0): returns a cached aclnt that
   talks to an in-process authclnt over a socketpair.  Note that it
   constructs authclnt with three arguments, while the constructor
   defined in this file takes two, so it would need updating before
   being re-enabled. */
#if 0
ptr<aclnt>
authclnt::getauthclnt ()
{
  /* XXX - Bit of a kludge; should really just override sfs_login */
  static ptr<aclnt> c;
  static ptr<authclnt> ac;
  if (c)
    return c;
  int fds[2];
  if (socketpair (AF_UNIX, SOCK_STREAM, 0, fds) < 0)
    fatal ("socketpair: %m\n");
  c = aclnt::alloc (axprt_stream::alloc (fds[0]), sfsauth_prog_2);
  ac = New refcounted<authclnt> (axprt_crypt::alloc (fds[1]),
				 (authunix_parms *) NULL, true);
  return c;
}
#endif

/* RPC dispatcher for the sfsauth_prog_2 service registered in the
   constructor.  A NULL SBP destroys this object (presumably the
   library's end-of-connection signal -- confirm against asrv).  Each
   known procedure is routed to its handler; unknown procedures are
   rejected with PROC_UNAVAIL. */
void
authclnt::dispatch (svccb *sbp)
{
  if (!sbp) {
    delete this;
    return;
  }

  switch (sbp->proc ()) {
  case SFSAUTH2_NULL:
    sbp->reply (NULL);
    return;

  case SFSAUTH2_LOGIN:
    {
      /* Login is answered synchronously from a local result. */
      sfsauth2_loginres loginres;
      sfsauth_login (&loginres, sbp->template getarg<sfsauth2_loginarg> (),
                     sfsauth_login_srp);
      sbp->replyref (loginres);
      return;
    }

  case SFSAUTH2_QUERY:
    sfsauth_query (sbp);
    return;

  case SFSAUTH2_UPDATE:
    sfsauth_update (sbp);
    return;

  case SFSAUTH2_SIGN:
    sfsauth_sign (sbp);
    return;

  default:
    sbp->reject (PROC_UNAVAIL);
    return;
  }
}

/* Handle SFSAUTH2_SIGN: have the server's Schnorr key half endorse the
   client's presignature over the request in ARG->req.

   The key half comes from the logged-in user's record (looked up by
   the call's authno) or, failing that, from the database record of
   ARG->user when the request is a NOCRED authreq addressed to this
   server's own authid.  Every failure is reported to the caller via
   res.errmsg; the request is written to the signature log before the
   signature is produced, and signing is refused if logging fails. */
void
authclnt::sfsauth_sign (svccb *sbp)
{
  sfsauth2_sign_arg *arg = sbp->template getarg<sfsauth2_sign_arg> ();
  sfsauth2_sign_res res (true);
  u_int32_t authno = sbp->getaui ();
  sfsauth_dbrec db;
  /* NOTE(review): restricted_sign is set below but never read in this
     function. */
  bool restricted_sign = false;
  urec_t *ur = NULL;
  sfsauth_keyhalf *kh = NULL;
  sfs_idname uname;

  /* Preferred source: the record created when this authno logged in. */
  if (authno && (ur = utab[authno])) {
    kh = &ur->kh;
    uname = ur->uname;
  }

  /* Fallback: no login record.  Only a NOCRED authreq for the auth
     service, bound to this connection's authid, may use the key half
     stored in the database for the named user. */
  if (!kh && arg->req.type == SFS_SIGNED_AUTHREQ && 
      arg->req.authreq->type == SFS_SIGNED_AUTHREQ_NOCRED &&
      arg->authinfo.service == SFS_AUTHSERV &&
      authid == arg->req.authreq->authid) {
    sfsauth_dbkey key (SFSAUTH_DBKEY_NAME);
    if ((*key.name = arg->user) && get_user_cursor (NULL, NULL, &db, key)) {
      kh = &db.userinfo->srvprivkey;
      uname = db.userinfo->name;
      restricted_sign = true;
    }
  } 

  if (!kh || kh->type != SFSAUTH_KEYHALF_PRIV) {
    res.set_ok (false);
    *res.errmsg = "No valid server private keyhalf for user";
    sbp->replyref (res);
    return;
  }
  if (arg->presig.type != SFS_2SCHNORR) {
    res.set_ok (false);
    *res.errmsg = "Can only answer 2-Schnorr requests";
    sbp->replyref (res);
    return;
  }
  res.sig->set_type (SFS_SCHNORR);
  /* Pick the stored key half that matches the client's public key hash. */
  int i = sfs_schnorr_pub::find (*kh, arg->pubkeyhash);
  if (i < 0) {
    res.set_ok (false);
    *res.errmsg = "No matching keyhalf found on server.";
    sbp->replyref (res);
    return;
  }
  const sfs_2schnorr_priv_xdr &spriv = (*kh->priv)[i];
  /* Validity in sprivk_tab is only checked for keys taken from a login
     record (ur != NULL), not for keys read from the database above. */
  if (ur && !sprivk_tab.is_valid (hash_sprivk (spriv))) {
    res.set_ok (false);
    *res.errmsg = "Server keyhalf is no longer valid.";
    sbp->replyref (res);
    return;
  }

  ref<schnorr_srv_priv> srv_key = New refcounted<schnorr_srv_priv> 
    (spriv.p, spriv.q, spriv.g, spriv.y, spriv.x);

  str msg = sigreq2str (arg->req);
  if (!msg) {
    res.set_ok (false);
    *res.errmsg = "Cannot marshal signature request";
    sbp->replyref (res);
    return ;
  }

  /* Unless the request is of type SFS_NULL, the authid inside it must
     equal the hash of the accompanying authinfo. */
  sfs_hash aid_tmp;
  if (arg->req.type != SFS_NULL && 
      (!sha1_hashxdr (aid_tmp.base (), arg->authinfo) || 
      !sigreq_authid_cmp (arg->req, aid_tmp))) {
    res.set_ok (false);
    *res.errmsg = "Incorrect authid in request";
    sbp->replyref (res);
    return ;
  }

  /* Log first; never hand out a signature that could not be logged. */
  if (!siglog (siglogline (*arg, uname))) {
    res.set_ok (false);
    *res.errmsg = "Refusing to sign: could not log signature";
    sbp->replyref (res);
    return;
  }
    
  srv_key->endorse_signature (&res.sig->schnorr->r, &res.sig->schnorr->s, 
			      msg, arg->presig.schnorr->r);
  sbp->replyref (res);
}

/* Format one signature-log line for the signing request ARG made on
   behalf of UNAME:
     SIGN:<uname>:<client_name>:<time>:<armor64 of marshaled ARG>\n
   Colons in the timestamp are replaced by "." so that ':' stays usable
   as the field separator.  Returns a NULL str if ARG cannot be
   marshaled. */
str
authclnt::siglogline (const sfsauth2_sign_arg &arg, const str &uname)
{
  str raw = xdr2str (arg);
  if (!raw)
    return NULL;

  str encoded = armor64 (raw);
  str stamp = single_char_sub (timestr (), ':', ".");

  strbuf line;
  line << "SIGN:" << uname;
  line << ":" << client_name;
  line << ":" << stamp;
  line << ":" << encoded << "\n";
  return line;
}

/* Write LINE to the signature log descriptor (logfd) with a single
   write () call.  Returns true only if LINE is non-NULL and the whole
   line was written. */
bool
siglog (const str &line)
{
  if (!line)
    return false;

  const int wrote = write (logfd, line.cstr (), line.len ());
  const int wanted = int (line.len ());
  return wrote >= wanted;
}

/* Build the line recorded in the signature log when the daemon starts:
   "sfsauthd restarted: <time>\n", with ':' in the timestamp replaced
   by ".". */
str
siglog_startup_msg ()
{
  str stamp = single_char_sub (timestr (), ':', ".");
  strbuf banner;
  banner << "sfsauthd restarted: " << stamp << "\n";
  return banner;
}

void
siglogv ()
{
  if (!siglog (siglog_startup_msg ()))
    fatal << "Cannot generate startup message for signature log\n";
}

/* Handle SFSAUTH2_UPDATE: let an authenticated user (or an admin)
   update a user record.

   Steps, each of which replies with an error message and returns on
   failure:
     1. The call's authno must refer to a unix credential with a login
        record that was not obtained with SFS_NOAUTH; otherwise the RPC
        is rejected with AUTH_REJECTEDCRED.
     2. The caller's own database record is reloaded and must still
        match the credential's name and uid.
     3. The request must be an SFS_UPDATEREQ bound to this connection's
        authid.
     4. A signature by the new public key is verified when supplied and
        is required unless the key is kept (SFSUP_KPPK) or the caller is
        an admin.  A signature by the old key is verified when supplied;
        without it only a unix-password login on a database allowing
        unix passwords may proceed, and then without admin rights.
     5. The target record is loaded writable, access and version are
        checked, the requested fields are copied, the server key half is
        updated, an audit string is recorded, and the record is written.

   Group updates are not implemented.  A group record that arrives with
   a "new key" signature is refused before verification, because
   verifying would read the userinfo arm of a record whose type is
   SFSAUTH_GROUP. */
void
authclnt::sfsauth_update (svccb *sbp)
{
  const sfsauth_cred *cp = NULL;
  urec_t *ur = NULL;
  if (sbp->getaui () >= credtab.size ()
      || !(cp = &credtab[sbp->getaui ()])
      || cp->type != SFS_UNIXCRED
      || !(ur = utab[sbp->getaui ()])
      || ur->authtype == SFS_NOAUTH) {
    sbp->reject (AUTH_REJECTEDCRED);
    return;
  }

  sfsauth2_update_res res (false);
  sfsauth_dbkey kname (SFSAUTH_DBKEY_NAME);
  *kname.name = cp->unixcred->username;
  bool oldsig = false;

  /* Reload the caller's own record; it must still match the cred. */
  sfsauth_dbrec cdbr;
  dbfile *cdbp;
  if (!get_user_cursor (&cdbp, NULL, &cdbr, kname)
      || cp->unixcred->username != cdbr.userinfo->name
      || cp->unixcred->uid != cdbr.userinfo->id) {
    *res.errmsg = "could not load credential db record";
    sbp->replyref (res);
    return;
  }

  sfsauth2_update_arg *argp = sbp->template getarg<sfsauth2_update_arg> ();
  if (argp->req.type != SFS_UPDATEREQ
      || (argp->req.rec.type != SFSAUTH_USER
          && argp->req.rec.type != SFSAUTH_GROUP)) {
    *res.errmsg = "invalid request";
    sbp->replyref (res);
    return;
  }
  u_int32_t opts = argp->req.opts;
  if (argp->req.authid != authid) {
    *res.errmsg = "invalid authid";
    sbp->replyref (res);
    return ;
  }

  static rxx adminrx ("(\\A|,)admin(\\z|,)");
  bool admin = cdbp->allow_admin && adminrx.match (cdbr.userinfo->privs);
  str reqxdr = xdr2str (argp->req);
  if (argp->newsig) {
    /* The new signature is checked against the public key in the
       request's user record, so the record must be a user record. */
    if (argp->req.rec.type != SFSAUTH_USER) {
      *res.errmsg = "group updates not implemented";
      sbp->replyref (res);
      return;
    }
    str e;
    if (!sfscrypt.verify (argp->req.rec.userinfo->pubkey, *(argp->newsig),
                          reqxdr, &e)) {
      *res.errmsg = str (strbuf ("new signature: " << e));
      sbp->replyref (res);
      return;
    }
  }
  else if (!(opts & SFSUP_KPPK) && !admin) {
    *res.errmsg = "Missing signature with new public key.";
    sbp->replyref (res);
    return;
  }

  if (argp->authsig) {
    str e;
    if (!sfscrypt.verify (cdbr.userinfo->pubkey, *(argp->authsig),
                          reqxdr, &e)) {
      *res.errmsg = str (strbuf ("old signature: " << e));
      sbp->replyref (res);
      return;
    } else
      oldsig = true;
  }
  else if (!cdbp->allow_unix_pwd || ur->authtype != SFS_UNIXPWAUTH) {
    *res.errmsg = "digital signature required";
    sbp->replyref (res);
    return;
  }
  else
    /* Password-only callers never get admin rights. */
    admin = false;

  dbfile *udbp;
  ptr<authcursor> uac;
  sfsauth_dbrec udbr;
  if (argp->req.rec.type == SFSAUTH_USER) {
    *kname.name = argp->req.rec.userinfo->name;
    if (!get_user_cursor (&udbp, &uac, &udbr, kname, true)) {
      *res.errmsg = "could not find or update user's record";
      sbp->replyref (res);
      return;
    }
    if (!admin && (udbr.userinfo->name != cdbr.userinfo->name
                   || udbr.userinfo->id != cdbr.userinfo->id)) {
      /* XXX - ignoring owner field for now */
      *res.errmsg = "access denied";
      sbp->replyref (res);
      return;
    }
    if (argp->req.rec.userinfo->vers < 1) {
      *res.errmsg = "version number of record must be greater than 0";
      sbp->replyref (res);
      return;
    }
    /* The new version must be exactly one past the stored one. */
    if (argp->req.rec.userinfo->vers != udbr.userinfo->vers + 1) {
      *res.errmsg = "version mismatch";
      sbp->replyref (res);
      return;
    }
    uac->ae.userinfo->vers = argp->req.rec.userinfo->vers;
    /* Each SFSUP_KP* option keeps the corresponding stored field. */
    if (!(opts & SFSUP_KPPK))
      uac->ae.userinfo->pubkey = argp->req.rec.userinfo->pubkey;
    if (!(opts & SFSUP_KPSRP))
      uac->ae.userinfo->pwauth = argp->req.rec.userinfo->pwauth;
    if (!(opts & SFSUP_KPESK))
      uac->ae.userinfo->privkey = argp->req.rec.userinfo->privkey;

    str err = update_srv_keyhalf (argp->req.rec.userinfo->srvprivkey,
                                  uac->ae.userinfo->srvprivkey,
                                  udbr.userinfo->srvprivkey, true, ur);
    if (err) {
      *res.errmsg = err;
      sbp->replyref (res);
      return;
    }

    strbuf sb;
    sb << "Last modified " << timestr () << " by " ;
    if (uid && !*uid)
      sb << "*superuser*";
    else
      sb << cp->unixcred->username;
    sb << "@" << client_name;
    uac->ae.userinfo->audit = sb;

    if (admin) {
      /* Admins may also change gid and privs.  A gid that differs from
         the stored one must map back and forth through the gid map. */
      u_int32_t gid = argp->req.rec.userinfo->gid;
      if (udbp->gidmap && gid != udbp->gidmap->map (uac->ae.userinfo->gid)) {
        gid = udbp->gidmap->unmap (gid);
        if (gid == badid
            || udbp->gidmap->map (gid) != argp->req.rec.userinfo->gid) {
          *res.errmsg = "bad gid";
          sbp->replyref (res);
          return;
        }
      }
      uac->ae.userinfo->gid = gid;
      uac->ae.userinfo->privs = argp->req.rec.userinfo->privs;
    }

    /* NOTE(review): allow_create and mkpub are taken from the caller's
       database (cdbp), not from the database the record was found in
       (udbp) -- confirm this is intended. */
    if (!uac->update (cdbp->allow_create)) {
      *res.errmsg = "database refused update";
      sbp->replyref (res);
      return;
    }
    res.set_ok (true);
    cdbp->mkpub ();
  }
  else {
    *res.errmsg = "group updates not implemented";
    sbp->replyref (res);
    return;
  }

  sbp->replyref (res);
}

/* Find a user record by name, id or public key (KEY) in the configured
   databases, trying them in order.

   On success returns true and, for each non-NULL output pointer,
   stores: the database in *DBPP, the open cursor in *ACP, and in *DBRP
   a copy of the record with pwauth and privkey blanked, the name
   prefixed with the database prefix, and uid/gid translated through
   the database's id maps.  With WRITABLE set, only databases that
   allow updates are considered and cursors are opened writable.

   For name and id lookups on a database that allows unix passwords, a
   user missing from the database is synthesized from the local passwd
   entry (version 0).  The passwd lookup by name is skipped when the
   key does not carry this database's prefix (strip_prefix returned a
   NULL str), since getpwnam must not be handed a null pointer.

   On failure returns false and, if DBRP is non-NULL, sets it to an
   SFSAUTH_ERROR record describing the problem. */
bool
authclnt::get_user_cursor (dbfile **dbpp, ptr<authcursor> *acp,
                           sfsauth_dbrec *dbrp, const sfsauth_dbkey &key,
                           bool writable)
{
  if (key.type != SFSAUTH_DBKEY_NAME && key.type != SFSAUTH_DBKEY_ID
      && key.type != SFSAUTH_DBKEY_PUBKEY) {
    if (dbrp) {
      dbrp->set_type (SFSAUTH_ERROR);
      *dbrp->errmsg = strbuf ("unsupported key type %d", key.type);
    }
    return false;
  }
  ptr<sfspub> pk;
  for (dbfile *dbp = dbfiles.base (); dbp < dbfiles.lim (); dbp++) {
    if (writable && !dbp->allow_update)
      continue;
    ptr<authcursor> ac = dbp->db->open (writable);
    if (!ac)
      continue;
    switch (key.type) {
    case SFSAUTH_DBKEY_NAME:
      {
        str name = dbp->strip_prefix (*key.name);
        if (!name || !ac->find_user_name (name)) {
          struct passwd *pw;
          /* Only consult passwd when the name belongs to this db. */
          if (dbp->allow_unix_pwd && name && (pw = getpwnam (name))) {
            sfsauth_dbrec rec (SFSAUTH_USER);
            rec.userinfo->name = *key.name;
            rec.userinfo->id = pw->pw_uid;
            rec.userinfo->vers = 0;
            rec.userinfo->gid = pw->pw_gid;
            ac->ae = rec;
            break;
          }
          continue;
        }
        break;
      }
    case SFSAUTH_DBKEY_ID:
      {
        u_int32_t id = dbp->uidmap ? dbp->uidmap->unmap (*key.id) : *key.id;
        if (id == badid || !ac->find_user_uid (id)) {
          struct passwd *pw;
          if (dbp->allow_unix_pwd && (pw = getpwuid (*key.id))) {
            sfsauth_dbrec rec (SFSAUTH_USER);
            if (dbp->prefix)
              rec.userinfo->name = dbp->prefix << "/" << pw->pw_name;
            else
              rec.userinfo->name = pw->pw_name;
            rec.userinfo->id = pw->pw_uid;
            rec.userinfo->vers = 0;
            rec.userinfo->gid = pw->pw_gid;
            ac->ae = rec;
            break;
          }
          continue;
        }
        break;
      }
    case SFSAUTH_DBKEY_PUBKEY:
      /* Import the key once and reuse it for every database. */
      if (!pk) {
        if (!(pk = sfscrypt.alloc (*key.key))) {
          warn << "Cannot import user public key.\n";
          return false;
        }
      }
      if (!(ac->find_user_pubkey (*pk) && *pk == ac->ae.userinfo->pubkey))
        continue;
      break;
    default:
      panic ("unreachable\n");
    }
    if (dbpp)
      *dbpp = dbp;
    if (acp)
      *acp = ac;
    if (dbrp) {
      *dbrp = ac->ae;
      /* Never hand secrets back through the record copy. */
      dbrp->userinfo->pwauth = "";
      dbrp->userinfo->privkey.setsize (0);
      if (dbp->prefix)
        dbrp->userinfo->name = dbp->prefix << "/" << dbrp->userinfo->name;
      if (dbp->uidmap)
        dbrp->userinfo->id = dbp->uidmap->map (dbrp->userinfo->id);
      if (dbrp->userinfo->id == badid)
        continue;
      if (dbp->gidmap)
        dbrp->userinfo->gid = dbp->gidmap->map (dbrp->userinfo->gid);
    }
    return true;
  }
  if (dbrp) {
    dbrp->set_type (SFSAUTH_ERROR);
    *dbrp->errmsg = "not found";
  }
  return false;
}

/* Answer a user query.  The record returned by get_user_cursor has its
   encrypted private key blanked; it is filled back in only when the
   caller is logged in via SRP as that very user (same name and uid).
   A server private key half is never returned: its presence is
   reported by switching the field to SFSAUTH_KEYHALF_FLAG. */
void
authclnt::query_user (svccb *sbp)
{
  sfsauth2_query_arg *arg = sbp->template getarg<sfsauth2_query_arg> ();
  ptr<authcursor> ac;
  sfsauth2_query_res res;

  if (get_user_cursor (NULL, &ac, &res, arg->key)) {
    if (sbp->getaui () < credtab.size ()) {
      const sfsauth_cred &cred = credtab[sbp->getaui ()];
      const urec_t *rec = utab[sbp->getaui ()];
      if (rec && cred.type == SFS_UNIXCRED
          && cred.unixcred->username == res.userinfo->name
          && cred.unixcred->uid == res.userinfo->id
          && rec->authtype == SFS_SRPAUTH)
        res.userinfo->privkey = ac->ae.userinfo->privkey;
    }
  }

  if (res.type == SFSAUTH_USER) {
    if (res.userinfo->srvprivkey.type == SFSAUTH_KEYHALF_PRIV)
      res.userinfo->srvprivkey.set_type (SFSAUTH_KEYHALF_FLAG);
  }

  sbp->replyref (res);
}

/* Answer an SRP-parameters query with the server's srpparms, or with
   an error record when none are configured. */
void
authclnt::query_srpparms (svccb *sbp)
{
  sfsauth2_query_res res (SFSAUTH_SRPPARMS);
  if (srpparms)
    res.srpparms->parms = srpparms;
  else {
    res.set_type (SFSAUTH_ERROR);
    *res.errmsg = "No SRP information available";
  }
  sbp->replyref (res);
}

/* Answer a certificate-info query.  Without a configured realm
   (empty sfsauthrealm) the server reports its own hostname with status
   SFSAUTH_CERT_SELF; otherwise it reports the realm name together with
   the configured certification paths, which are referenced in place
   (freemode::NOFREE) rather than copied. */
void
authclnt::query_certinfo (svccb *sbp)
{
  sfsauth2_query_res res (SFSAUTH_CERTINFO);
  if (sfsauthrealm.len () == 0) {
    res.certinfo->name = myservinfo.cr7->host.hostname;
    res.certinfo->info.set_status (SFSAUTH_CERT_SELF);
  }
  else {
    res.certinfo->name = sfsauthrealm;
    res.certinfo->info.set_status (SFSAUTH_CERT_REALM);
    res.certinfo->info.certpaths->set (sfsauthcertpaths.base (),
                                       sfsauthcertpaths.size (),
                                       freemode::NOFREE);
  }
  sbp->replyref (res);
}

/* Answer a group query by name or by gid.  Databases are searched in
   order and the first matching group record is returned; if none
   matches, an error record saying "not found" is returned.

   Lookups by id translate the requested gid back to the database's own
   numbering through the database's *gid* map.  (This used to go
   through the uid map, which is the wrong id space for groups and
   disagrees with how gids are mapped everywhere else in this file.) */
void
authclnt::query_group (svccb *sbp)
{
  sfsauth2_query_arg *arg = sbp->template getarg<sfsauth2_query_arg> ();
  sfsauth2_query_res res;
  if (arg->key.type != SFSAUTH_DBKEY_NAME
      && arg->key.type != SFSAUTH_DBKEY_ID) {
    res.set_type (SFSAUTH_ERROR);
    *res.errmsg = strbuf ("unsupported key type %d", arg->key.type);
    sbp->replyref (res);
    return;
  }
  for (dbfile *dbp = dbfiles.base (); dbp < dbfiles.lim (); dbp++) {
    ptr<authcursor> ac = dbp->db->open ();
    if (!ac)
      continue;
    switch (arg->key.type) {
    case SFSAUTH_DBKEY_NAME:
      {
        str name = dbp->strip_prefix (*arg->key.name);
        if (!name || !ac->find_group_name (name))
          continue;
        break;
      }
    case SFSAUTH_DBKEY_ID:
      {
        u_int32_t id = dbp->gidmap ? dbp->gidmap->unmap (*arg->key.id)
          : *arg->key.id;
        if (id == badid || !ac->find_group_gid (id))
          continue;
        break;
      }
    default:
      panic ("unreachable");
    }
    res = ac->ae;
    assert (res.type == SFSAUTH_GROUP);
    sbp->replyref (res);
    return;
  }
  res.set_type (SFSAUTH_ERROR);
  *res.errmsg = "not found";
  sbp->replyref (res);
}


/* Handle SFSAUTH2_QUERY by routing on the requested record type.
   Unsupported types are answered with an error record. */
void
authclnt::sfsauth_query (svccb *sbp)
{
  sfsauth2_query_arg *arg = sbp->template getarg<sfsauth2_query_arg> ();
  switch (arg->type) {
  case SFSAUTH_USER:
    query_user (sbp);
    return;
  case SFSAUTH_GROUP:
    query_group (sbp);
    return;
  case SFSAUTH_CERTINFO:
    query_certinfo (sbp);
    return;
  case SFSAUTH_SRPPARMS:
    query_srpparms (sbp);
    return;
  default:
    {
      sfsauth2_query_res errres;
      errres.set_type (SFSAUTH_ERROR);
      *errres.errmsg = strbuf ("unsupported query type %d", arg->type);
      sbp->replyref (errres);
      return;
    }
  }
}

/* Handle a login request arriving on this connection itself (as
   opposed to SFSAUTH2_LOGIN, which is answered through dispatch).

   The request is wrapped in an sfsauth2_loginarg bound to this
   connection's authid and processed by sfsauth_login with SRP allowed.
   On success the sequence number is checked, an authno is allocated
   for the returned credentials and a user record is stored under it.

   If no authno can be allocated the login is reported as
   SFSLOGIN_BAD and nothing else is done: no user record is inserted
   under authno 0, and the authno field of the result -- which is not
   the active arm once the status is BAD -- is not written. */
void
authclnt::sfs_login (svccb *sbp)
{
  if (!authid_valid) {
    sbp->replyref (sfs_loginres (SFSLOGIN_ALLBAD));
    return;
  }
  sfsauth2_loginarg la;
  la.arg = *sbp->template getarg<sfs_loginarg> ();
  la.authid = authid;
  la.source = client_name << "!" << progname;

  sfsauth_dbrec dbr;
  sfsauth2_loginres lr;
  sfs_authtype authtype;
  sfsauth_login (&lr, &la, true, &dbr, &authtype);

  sfs_loginres res (lr.status);
  switch (lr.status) {
  case SFSLOGIN_OK:
    if (!seqstate.check (la.arg.seqno)
        || lr.resok->creds.size () < 1)
      res.set_status (SFSLOGIN_BAD);
    else {
      u_int32_t authno;
      authno = authalloc (lr.resok->creds.base (),
                          lr.resok->creds.size ());
      if (!authno) {
        warn << "ran out of authnos (or bad cred type)\n";
        res.set_status (SFSLOGIN_BAD);
      }
      else {
        utab_insert (authno, authtype, dbr);
        *res.authno = authno;
      }
    }
    break;
  case SFSLOGIN_MORE:
    /* Multi-round login (SRP): pass the next message to the client. */
    *res.resmore = *lr.resmore;
    break;
  default:
    break;
  }

  sbp->replyref (res);
}

/* Store a fresh user record for AUTHNO (login method AT, database
   record DBR), replacing and freeing any record already registered
   under the same authno. */
void
authclnt::utab_insert (u_int32_t authno, sfs_authtype at,
                       const sfsauth_dbrec &dbr)
{
  if (urec_t *stale = utab[authno])
    urecfree (stale);

  urec_t *fresh = New urec_t (authno, at, dbr);
  utab.insert (fresh);
  ulist.insert_head (fresh);
}

/* Handle logout: free the user record kept for the authno named in the
   request, if there is one, then let the base class finish the
   logout. */
void
authclnt::sfs_logout (svccb *sbp)
{
  const u_int32_t authno = *sbp->template getarg<u_int32_t> ();
  if (urec_t *rec = utab[authno])
    urecfree (rec);
  sfsserv::sfs_logout (sbp);
}

/* Return true iff every byte of SOURCE lies in the range 0x20..0x7e,
   i.e. the string contains no control characters and no DEL.  An
   empty string is accepted. */
inline bool
sourceok (str source)
{
  u_int pos = 0;
  while (pos < source.len ()) {
    if (source[pos] < 0x20 || source[pos] >= 0x7f)
      return false;
    pos++;
  }
  return true;
}

/* Core login routine shared by sfs_login and SFSAUTH2_LOGIN.

   Validates the source string in LAP, unmarshals the certificate into
   an sfs_autharg2 and hands it to the handler for its authentication
   type (public key, unix password or SRP).  SRP_OK says whether SRP
   may be used; it is also passed on to unixpwauth.  When ATP is
   non-NULL it receives the authentication type; DBP, when non-NULL, is
   passed to the handlers to receive the user's database record.

   A result that claims success without any credential is downgraded
   to SFSLOGIN_BAD.  Successful logins are logged. */
void
authclnt::sfsauth_login (sfsauth2_loginres *resp, const sfsauth2_loginarg *lap,
                         bool srp_ok, sfsauth_dbrec *dbp, sfs_authtype *atp)
{
  if (!sourceok (lap->source)) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "invalid source in login reguest";
    return;
  }

  sfs_autharg2 aa;
  if (!bytes2xdr (aa, lap->arg.certificate)) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "cannot unmarshal certificate";
    return;
  }
  if (atp)
    *atp = aa.type;

  /* Human-readable name of the method, used only for the log line. */
  str method;
  switch (aa.type) {
  case SFS_AUTHREQ:
  case SFS_AUTHREQ2:
    method = "public key";
    authreq (resp, lap, &aa, dbp);
    break;
  case SFS_UNIXPWAUTH:
    method = "unix password";
    unixpwauth (resp, lap, &aa, srp_ok, dbp);
    break;
  case SFS_SRPAUTH:
    method = "SRP password";
    if (!srp_ok) {
      resp->set_status (SFSLOGIN_BAD);
      *resp->errmsg = "SRP authentication of client to third party not allowed";
    }
    else
      srpauth (resp, lap, &aa, dbp);
    break;
  default:
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = strbuf ("unknown login type %d", aa.type);
    break;
  }

  if (resp->status != SFSLOGIN_OK)
    return;
  if (!resp->resok->creds.size ()) {
    resp->set_status (SFSLOGIN_BAD);
    return;
  }

  if (resp->resok->creds[0].type == SFS_UNIXCRED) {
    warn << "accepted user " << resp->resok->creds[0].unixcred->username
         << " from " << lap->source
         << " using " << method << "\n";
    return;
  }
  if (resp->resok->creds[0].type == SFS_PKCRED)
    warn << "accepted pubkey " << *resp->resok->creds[0].pkhash
         << " from " << lap->source
         << " using " << method << "\n";
}

/* Public-key login.  AAP holds either an old-style SFS_AUTHREQ (a
   signed, recoverable request) or an SFS_AUTHREQ2 (request plus
   detached signature).

   After the signature has been verified and the request checked
   against LAP (type, sequence number, authid), every database is
   searched for a user record with the presented public key and, if
   the request names a user, with that name.  A record whose stored key
   equals the presented key yields unix credentials plus a public-key
   hash credential; a record that merely matches the lookup yields
   SFS_NOCRED when the request type is SFS_SIGNED_AUTHREQ_NOCRED.
   When no record is accepted, a request other than NOCRED still
   succeeds with only the public-key hash credential, while a NOCRED
   request fails with "bad login".

   The matching database record is copied to *DBRP when DBRP is
   non-NULL.

   Every error path first switches *RESP to SFSLOGIN_BAD and only then
   stores the message, so errmsg is always written on the active arm
   of the result (the key-load failure path used to do it the other
   way round). */
void
authclnt::authreq (sfsauth2_loginres *resp,
                   const sfsauth2_loginarg *lap, const sfs_autharg2 *aap,
                   sfsauth_dbrec *dbrp)
{
  ptr<sfspub> vrfy;
  str logname;
  sfs_msgtype mtype;
  str s;
  if (aap->type == SFS_AUTHREQ) {

    const sfs_pubkey &kp = aap->authreq1->usrkey;
    if (!(vrfy = sfscrypt.alloc (kp, SFS_VERIFY))) {
      resp->set_status (SFSLOGIN_BAD);
      *resp->errmsg = "Cannot load public Rabin key";
      return ;
    }
    sfs_signed_authreq authreq;
    str msg;
    if (!vrfy->verify_r (aap->authreq1->signed_req, sizeof (authreq), msg)
        || !str2xdr (authreq, msg)
        || (authreq.type != SFS_SIGNED_AUTHREQ &&
            authreq.type != SFS_SIGNED_AUTHREQ_NOCRED)
        || authreq.seqno != lap->arg.seqno
        || authreq.authid != lap->authid) {
      resp->set_status (SFSLOGIN_BAD);
      *resp->errmsg = "bad signature";
      return;
    }
    mtype = authreq.type;
    /* usrinfo is a fixed-size buffer: take it up to its first NUL, or
       in full when it contains none. */
    if (authreq.usrinfo[0]) {
      if (memchr (authreq.usrinfo.base (), 0, authreq.usrinfo.size ()))
        logname = authreq.usrinfo.base ();
      else
        logname.setbuf (authreq.usrinfo.base (), authreq.usrinfo.size ());
    }
  }
  else {
    str e;
    if (aap->sigauth->req.user.len ())
      logname = aap->sigauth->req.user;
    mtype = aap->sigauth->req.type;
    str msg = xdr2str (aap->sigauth->req);
    resp->set_status (SFSLOGIN_BAD);
    if (!(vrfy = sfscrypt.alloc (aap->sigauth->key, SFS_VERIFY))) {
      *resp->errmsg = "cannot load public key";
      return;
    } else if (!vrfy->verify (aap->sigauth->sig, msg, &e)) {
      *resp->errmsg = str (strbuf ("Bad login: " << e));
      return;
    }
    if (!checkreq (resp, lap, &aap->sigauth->req))
      return;
  }
  for (dbfile *dbp = dbfiles.base (); dbp < dbfiles.lim (); dbp++) {
    ptr<authcursor> ac = dbp->db->open ();

    // XXX - in long form for aiding in debugging
    if (!ac)
      continue;
    if (!ac->find_user_pubkey (*vrfy))
      continue;
    if (ac->ae.type != SFSAUTH_USER)
      continue;
    if (logname) {
      if (dbp->prefix) {
        if (logname != dbp->prefix << "/" << ac->ae.userinfo->name)
          continue;
      } else {
        if (logname != ac->ae.userinfo->name)
          continue;
      }
    }

    if (*vrfy == ac->ae.userinfo->pubkey) {
      if (!setuser (resp, ac->ae, dbp))
        continue;
      setuser_pkhash (resp, vrfy);
    } else if (mtype == SFS_SIGNED_AUTHREQ_NOCRED) {
      resp->set_status (SFSLOGIN_OK);
      resp->resok->creds.setsize (1);
      resp->resok->creds[0].set_type (SFS_NOCRED);
    } else {
      continue;
    }

#if 0
    str source = "";
    if (lap->source.len ())
      source = strbuf () << " from " << lap->source;
    str name = ac->ae.userinfo->name;
    if (dbp->prefix)
      name = dbp->prefix << "/" << name;
    warn << "signed login by " << name << source << "\n";
#endif
    if (dbrp)
      *dbrp = ac->ae;
    return;
  }
  /* No database record was accepted. */
  if (mtype != SFS_SIGNED_AUTHREQ_NOCRED) {
    resp->set_status (SFSLOGIN_OK);
    setuser_pkhash (resp, vrfy);
  } else {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "bad login";
  }
}

/* True when an empty password PW is acceptable: OK is set and the
   connection belongs to a local client with uid 0.  Relies on the
   member `uid' being in scope. */
#define ROOT_SU(pw, ok)   (!*(pw) && (ok) && uid && *uid == 0)

/* Unix password login.

   The user named in AAP must exist in the local passwd database, the
   connection must come from a local client (uid set), and at least
   one database must allow unix passwords; the first such database is
   used.  The password is checked with crypt () against the shadow
   entry when one is available, otherwise against the passwd entry.
   A local uid-0 client supplying an empty password is let through
   without a password check when ROOT_SU_OK is set.  The user's shell
   must be listed by getusershell ().

   crypt () may return NULL (for example when the stored hash is not a
   usable setting string); that is treated as a failed login instead of
   being passed to strcmp.

   On success *RESP carries the user's credentials and, if DBRP is
   non-NULL, *DBRP the user's record.  Every failure after the user
   checks is logged and reported as SFSLOGIN_BAD. */
void
authclnt::unixpwauth (sfsauth2_loginres *resp, const sfsauth2_loginarg *lap,
                      const sfs_autharg2 *aap, bool root_su_ok,
                      sfsauth_dbrec *dbrp)
{
  dbfile *dbp = dbfiles.base ();
  str name, source;
  ptr<authcursor> ac;
  bool flag = false;

  struct passwd *pe = getpwnam (aap->pwauth->req.user.cstr ());
  const char *pw;
  if (!pe)
    goto badlogin;
  pw = aap->pwauth->password.cstr ();

  /* Find the first database that permits unix password logins. */
  while (dbp < dbfiles.lim ()) {
    if (dbp->allow_unix_pwd) {
      flag = true;
      break;
    }
    dbp++;
  }
  if (!uid)
    goto badlogin;
  if (!flag)
    goto badlogin;

#if HAVE_GETSPNAM
  if (struct spwd *spe = getspnam (aap->pwauth->req.user.cstr ())) {
    if (!ROOT_SU (pw, root_su_ok)) {
      const char *enc = crypt (pw, spe->sp_pwdp);
      if (!enc || strcmp (spe->sp_pwdp, enc))
        goto badlogin;
    }
  }
  else
#endif /* HAVE_GETSPNAM */
    if (!ROOT_SU (pw, root_su_ok)) {
      const char *enc = crypt (pw, pe->pw_passwd);
      if (!enc || strcmp (pe->pw_passwd, enc))
        goto badlogin;
    }

  ac = dbp->db->open ();
  if (!ac)
    goto badlogin;
  /* Users missing from the database get a record synthesized from
     their passwd entry. */
  if (!ac->find_user_name (pe->pw_name)) {
    ac->ae.set_type (SFSAUTH_USER);
    ac->ae.userinfo->name = pe->pw_name;
    ac->ae.userinfo->id = pe->pw_uid;
    ac->ae.userinfo->vers = 0;
    ac->ae.userinfo->gid = pe->pw_gid;
  }

  if (!validshell (pe->pw_shell)) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "bad shell";
    goto badlogin;
  }

  if (!setuser (resp, ac->ae, dbp))
    return;
#if 0
  source = "";
  if (lap->source.len ())
    source = strbuf () << " from " << lap->source;
  name = ac->ae.userinfo->name;
  if (dbp->prefix)
    name = dbp->prefix << "/" << name;
  warn << "password login by " << name << source << "\n";
#endif

  if (dbrp)
    *dbrp = ac->ae;
  return;

 badlogin:
  warn << "BAD unix password login attempt for "
       << aap->pwauth->req.user.cstr ()
       << " from " << client_name << "\n";
  if (resp->status != SFSLOGIN_BAD)
    resp->set_status (SFSLOGIN_BAD);
  /* Keep a more specific message (e.g. "bad shell") if one was set. */
  if (!resp->errmsg->len ())
    *resp->errmsg = "bad login";
}

/* SRP password login, a multi-round exchange.  State is kept between
   calls in the members srp, srp_ac, srp_dbp, srp_seqno and srp_authid.

   First round (empty client message): look up the user, remember the
   sequence number and authid, and answer SFSLOGIN_MORE with the first
   server message from srp.init.
   Later rounds: require a pending session with matching authid and
   seqno, feed the client message to srp.next, and either continue
   (SFSLOGIN_MORE), finish with the user's credentials, or fail.
   The pending session (srp_ac) is cleared on every outcome except
   SFSLOGIN_MORE and the "unknown SRP session" error. */
void
authclnt::srpauth (sfsauth2_loginres *resp, const sfsauth2_loginarg *lap,
		   const sfs_autharg2 *aap, sfsauth_dbrec *dbrp)
{
  /* The request must be bound to the authid of the login argument.
     Unless sfsauth_login_srp is set (see the constructor), it must also
     match this connection's own authid. */
  if (aap->srpauth->req.authid != lap->authid
      || (!sfsauth_login_srp && (!authid_valid || lap->authid != authid))) {
    srp_ac = NULL;
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "invalid or missing authid";
    return;
  }
  /* An empty message starts a new exchange. */
  if (aap->srpauth->msg.size () == 0) {
    sfsauth_dbkey kname (SFSAUTH_DBKEY_NAME);
    *kname.name = aap->srpauth->req.user;
    if (!get_user_cursor (&srp_dbp, &srp_ac, NULL, kname)) {
      srp_ac = NULL;
      resp->set_status (SFSLOGIN_BAD);
      *resp->errmsg = "bad login";
      return;
    }
    srp_seqno = lap->arg.seqno;
    srp_authid = lap->authid;
    resp->set_status (SFSLOGIN_MORE);
    switch (srp.init (resp->resmore.addr (), &aap->srpauth->msg,
		      srp_authid, srp_ac->ae.userinfo->name,
		      srp_ac->ae.userinfo->pwauth)) {
    case SRP_NEXT:
      return;
    default:
      srp_ac = NULL;
      resp->set_status (SFSLOGIN_BAD);
      *resp->errmsg = "bad login";
      return;
    }
  }
  /* Continuation: must belong to the exchange started above. */
  if (!srp_ac || srp_authid != lap->authid || srp_seqno != lap->arg.seqno) {
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "unknown SRP session";
    return;
  }
  srpmsg msg;
  switch (srp.next (&msg, &aap->srpauth->msg)) {
  case SRP_NEXT:
    resp->set_status (SFSLOGIN_MORE);
    *resp->resmore = msg;
    return;
  case SRP_LAST:
    /* Exchange complete: hand out credentials together with the final
       server message. */
    if (setuser (resp, srp_ac->ae, srp_dbp)) 
      resp->resok->resmore = msg;
    if (dbrp) *dbrp = srp_ac->ae;
   break;
  default:
    warn << "BAD SRP login attempt for "
	 << mkname (srp_dbp, srp_ac->ae.userinfo->name)
	 << " from " << client_name << "\n";
    resp->set_status (SFSLOGIN_BAD);
    *resp->errmsg = "bad login";
    break;
  }
  /* The exchange is over, successfully or not. */
  srp_ac = NULL;
}

/* Apply the key-half update UPDKH to NEWKH, given the currently stored
   key half OLDKH, and record the result in the login record UR.

   UPDKH of type NONE changes nothing.  A DELTA adds *updkh.delta to
   the secret x of the newest stored key (modulo q) and requires an
   existing key.  A PRIV update installs a new key in slot 0 and shifts
   the older keys down, keeping at most SPRIVK_HISTORY_LEN entries; it
   is only permitted when CANCLEAR is set.  sprivk_tab is updated so
   that superseded keys are invalidated and the new key is bound.

   Returns NULL on success or an error message.  On an error return
   NEWKH may already have been modified. */
str
authclnt::update_srv_keyhalf (const sfsauth_keyhalf &updkh,
			      sfsauth_keyhalf &newkh,
			      const sfsauth_keyhalf &oldkh,
			      bool canclear, urec_t *ur)
{
  bool hasoldkh = false;
  const sfsauth_keyhalf_type &kht = updkh.type;
  if (kht == SFSAUTH_KEYHALF_NONE)
    return NULL;

  newkh.set_type (SFSAUTH_KEYHALF_PRIV);
  u_int okeys = 0;
  if (oldkh.type == SFSAUTH_KEYHALF_PRIV)
    okeys =  oldkh.priv->size ();
  u_int nkeys;
  if (oldkh.type == SFSAUTH_KEYHALF_PRIV && okeys >= 1) {
    hasoldkh = true;
    if (kht == SFSAUTH_KEYHALF_DELTA) {
      /* A delta keeps the history as is; only slot 0 changes below.
	 NOTE(review): newkh.priv is not resized on this path, so it is
	 presumably expected to already hold okeys entries -- confirm. */
      nkeys = okeys;
      for (u_int i = 1; i < okeys; i++) 
	(*newkh.priv)[i] = (*oldkh.priv)[i];
    } else {
      /* Shift the history down by one, dropping the oldest key once
	 the history is full. */
      nkeys = (okeys == SPRIVK_HISTORY_LEN) ? okeys : okeys + 1;
      newkh.priv->setsize (nkeys);
      for (u_int i = 1; i < nkeys; i++)
	(*newkh.priv)[i] = (*oldkh.priv)[i-1];
    }
  } else {
    nkeys = 1;
    newkh.priv->setsize (1);
  }

  if (kht == SFSAUTH_KEYHALF_DELTA) {
    if (!hasoldkh) 
      return "Cannot apply key delta: no key currently exists!";
    (*newkh.priv)[0] = (*oldkh.priv)[0];
    (*newkh.priv)[0].x += *updkh.delta;
    (*newkh.priv)[0].x %= (*newkh.priv)[0].q;
    /* The pre-delta key must no longer be usable for signing. */
    sprivk_tab.invalidate (hash_sprivk ((*oldkh.priv)[0]));
    sprivk_tab.bind (hash_sprivk ((*newkh.priv)[0]));
  } else if (kht == SFSAUTH_KEYHALF_PRIV) {
    if (!canclear)
      return "Can only explicitly set server keyhalf on register or signed "
	     "update.";
    /* nkeys == okeys only when the history was already full, i.e. the
       oldest key has just been dropped from it. */
    if (nkeys == okeys) 
      sprivk_tab.invalidate (hash_sprivk ((*oldkh.priv)[okeys - 1]));
    (*newkh.priv)[0] = (*updkh.priv)[0];
    sprivk_tab.bind (hash_sprivk ((*newkh.priv)[0]));
  }
  /* NOTE(review): ur is dereferenced unconditionally; the one caller in
     this file (sfsauth_update) passes a non-NULL record. */
  ur->kh = newkh;
  
  return NULL;
}
