/* $Id: rexd.C,v 1.27 2002/11/28 16:20:55 kaminsky Exp $ */

/*
 *
 * Copyright (C) 2000-2001 Eric Peterson (ericp@lcs.mit.edu)
 * Copyright (C) 2000-2001 Michael Kaminsky (kaminsky@lcs.mit.edu)
 * Copyright (C) 2000 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "arpc.h"
#include "rex_prot.h"
#include "crypt.h"
#include "sfsserv.h"
#include "sfscrypt.h"

/* Not referenced anywhere in this file; presumably consulted by the SFS
 * library code linked into rexd.  NOTE(review): confirm its meaning. */
int execprotect = 1;

/* Value of PATH placed in the environment of every spawned session. */
#define REXSESS_DEFAULT_PATH "/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/sbin"

/* Host private key: loaded by loadkey(), handed out by rexclient::doconnect(). */
ptr<sfspriv> sk;
/* Server info advertised to clients; populated in main(). */
sfs_servinfo servinfo;
/* Wrapper around servinfo; main() uses it to print the path being served. */
ptr<sfs_servinfo_w> siw;
/* Full path of the newaid helper, which each rexsess spawns as its child. */
str newaid;

/*
 * One remote-execution session.  A session is created by REXD_SPAWN,
 * registered in sesstab under its session ID, and owns a newaid child
 * process reached over a Unix-domain transport.  Later REXD_ATTACH calls
 * hand client connections to that child (see attach/ctlconnect).
 */
class rexsess {
  /* Set to true by the destructor; shared with pending ctlconnect
   * callbacks so they can tell the session has been deleted. */
  ref<bool> destroyed;
  /* Key-derivation data for the two directions.  The share fields are
   * filled in by the constructor, seqno by seq2sessinfo(); both are
   * zeroed in the destructor. */
  rex_sesskeydat kscdat;
  rex_sesskeydat kcsdat;
  /* Tracks which sequence numbers have already been used by attach(). */
  seqcheck seqstate;
  /* Transport to the spawned newaid child, and an rexctl_prog_1 RPC
   * client over it. */
  ptr<axprt_unix> x;
  ptr<aclnt> c;

  static void postfork (const sfsauth_cred *credp);
  void ctlconnect (ref<bool> abort, ref<sfs_sessinfo> si,
		   ref<axprt_stream> xs);
  void seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip);
  /* Installed as the EOF callback on c: the session dies with its child
   * connection. */
  void eof () { delete this; }

public:
  /* Session ID (hash of the seqno-0 session info); key in sesstab. */
  sfs_hash sessid;
  ihash_entry<rexsess> link;

  rexsess (const sfsauth_cred *credp, const rexd_spawn_arg *argp,
	   rexd_spawn_res *resp);
  ~rexsess ();
  void attach (svccb *sbp);
};

/* All live sessions, keyed by rexsess::sessid.  Sessions insert
 * themselves in the constructor and remove themselves in the destructor;
 * REXD_ATTACH looks them up here. */
ihash<sfs_hash, rexsess, &rexsess::sessid, &rexsess::link> sesstab;

/*
 * Create a session for the client identified by credp, answering the
 * REXD_SPAWN request argp in *resp.
 *
 * The server halves of the two key shares are chosen at random and sent
 * back in *resp; together with the client halves from argp they form
 * kcsdat/kscdat, from which the session ID is derived (seqno 0).  The
 * session is then registered in sesstab, and a newaid child is spawned
 * (running argp->command, if given) with a minimal environment.
 *
 * If the child cannot be spawned, *resp is set to SFS_TEMPERR and the
 * session is scheduled for deletion on the next event-loop pass, rather
 * than dereferencing a null transport.
 *
 * credp must be an SFS_UNIXCRED credential: its unixcred fields are read
 * unconditionally below (dispatch() checks this before constructing).
 */
rexsess::rexsess (const sfsauth_cred *credp, const rexd_spawn_arg *argp,
		  rexd_spawn_res *resp)
  : destroyed (New refcounted<bool> (false)), seqstate (32)
{
  resp->set_err (SFS_OK);
  rnd.getbytes (resp->resok->kmsg.kcs_share.base (),
		resp->resok->kmsg.kcs_share.size ());
  rnd.getbytes (resp->resok->kmsg.ksc_share.base (),
		resp->resok->kmsg.ksc_share.size ());
  kcsdat.type = SFS_KCS;
  kcsdat.sshare = resp->resok->kmsg.kcs_share;
  kcsdat.cshare = argp->kmsg.kcs_share;
  kscdat.type = SFS_KSC;
  kscdat.sshare = resp->resok->kmsg.ksc_share;
  kscdat.cshare = argp->kmsg.ksc_share;

  /* Presumably marks seqno 0 as used, so attach() cannot accept it. */
  seqstate.check (0);
  seq2sessinfo (0, &sessid, NULL);
  sesstab.insert (this);

  /* newaid -U<uid> -G -- [command args...] */
  vec<str> av;
  av.push_back (newaid);
  if (credp->type == SFS_UNIXCRED)
    av.push_back (strbuf ("-U%d", credp->unixcred->uid));
  else
    av.push_back (strbuf ("-U-2"));
  av.push_back ("-G");
  av.push_back ("--");
  if (argp->command.size ()) {
    av.push_back (fix_exec_path (argp->command[0]));
    for (size_t i = 1; i < argp->command.size (); i++)
      av.push_back (argp->command[i]);
  }
  av.push_back (NULL);

  /* Set up environment just for exec */
  const char *evarstosave[] = {
#ifdef MAINTAINER
    "SFS_RELEASE",
    "SFS_RUNINPLACE",
    "SFS_HASHCOST",
    "SFS_PORT",
    "ASRV_TRACE",
    "ACLNT_TRACE",
#endif /* MAINTAINER */
    NULL
  };

  vec<str> envs;
  for (int v = 0; evarstosave[v]; v++) {
    char *val = getenv (evarstosave[v]);
    if (val)
      envs.push_back (strbuf () << evarstosave[v] << "=" << val);
  }
  envs.push_back (strbuf () << "USER=" << credp->unixcred->username);
  envs.push_back (strbuf () << "LOGNAME=" << credp->unixcred->username);
  envs.push_back (strbuf () << "HOME=" << credp->unixcred->homedir);
  envs.push_back (strbuf () << "SHELL=" << credp->unixcred->shell);
  envs.push_back (strbuf () << "PATH=" REXSESS_DEFAULT_PATH);
#ifdef MAILPATH
  envs.push_back (strbuf () << "MAIL=" MAILPATH "/" << credp->unixcred->username);
#endif

  /* NULL-terminated char * view of envs; envs must outlive the spawn. */
  vec<char *> env;
  for (const str *s = envs.base (), *e = envs.lim (); s < e; s++)
    env.push_back (const_cast<char *> (s->cstr ()));
  env.push_back (NULL);

  x = axprt_unix_aspawnv (newaid, av, 0, wrap (&postfork, credp), env.base ());
  if (!x) {
    /* axprt_unix_aspawnv failed.  Previously this fell through and
     * crashed on x->allow_recvfd.  Report a temporary error instead, and
     * tear the session down (which also removes it from sesstab) once
     * the current dispatch has returned; deleting ourselves from inside
     * the constructor is not safe. */
    warn << "could not spawn " << newaid << " for new session\n";
    resp->set_err (SFS_TEMPERR);
    delaycb (0, wrap (this, &rexsess::eof));
    return;
  }
  x->allow_recvfd = false;
  c = aclnt::alloc (x, rexctl_prog_1);
  c->seteofcb (wrap (this, &rexsess::eof));
}

rexsess::~rexsess ()
{
  /* Unregister first so REXD_ATTACH can no longer find this session. */
  sesstab.remove (this);

  /* Scrub the key-derivation material before the memory goes away. */
  bzero (&kcsdat, sizeof (kcsdat));
  bzero (&kscdat, sizeof (kscdat));

  /* Let any ctlconnect callback still queued know we are gone. */
  *destroyed = true;
}

/*
 * Derive the session info for sequence number seqno.  The kcs and ksc
 * keys are SHA-1 hashes of the XDR-encoded kcsdat/kscdat (with seqno
 * stored into both).  If sidp is non-NULL, it receives the hash of the
 * resulting sfs_sessinfo.  If sip is non-NULL, it receives a copy of
 * that info.  The local copy is zeroed before returning.
 */
void
rexsess::seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip)
{
  kscdat.seqno = seqno;
  kcsdat.seqno = seqno;

  sfs_sessinfo info;
  info.type = SFS_SESSINFO;
  info.kcs.setsize (sha1::hashsize);
  info.ksc.setsize (sha1::hashsize);
  sha1_hashxdr (info.kcs.base (), kcsdat, true);
  sha1_hashxdr (info.ksc.base (), kscdat, true);

  if (sidp != NULL)
    sha1_hashxdr (sidp->base (), info, true);
  if (sip != NULL)
    *sip = info;

  /* Don't leave key bytes behind on the stack. */
  bzero (info.kcs.base (), info.kcs.size ());
  bzero (info.ksc.base (), info.ksc.size ());
}

/*
 * Post-fork hook passed to axprt_unix_aspawnv for the newaid child
 * (presumably run in the child between fork and exec).  It:
 *   - requires a Unix credential (fatal otherwise);
 *   - sets the group list (primary gid first, then supplementary
 *     groups, capped at NGROUPS_MAX) and the gid;
 *   - starts a new session with setsid (failure only warned about);
 *   - changes to the user's home directory (failure only warned about).
 */
void
rexsess::postfork (const sfsauth_cred *credp)
{
  if (credp->type != SFS_UNIXCRED)
    fatal ("setpriv: invalid credential type %d\n", int (credp->type));

  size_t ngroups = credp->unixcred->groups.size () + 1;
  if (ngroups > size_t (NGROUPS_MAX))
    ngroups = NGROUPS_MAX;

  GETGROUPS_T groups[NGROUPS_MAX];
  groups[0] = credp->unixcred->gid;
  for (size_t j = 0; j + 1 < ngroups; j++)
    groups[j + 1] = credp->unixcred->groups[j];

  /* No setuid here: that is left to the exec'd process, because we are
   * too paranoid about ptrace, signals, core dumps, etc. while this
   * process image still holds private keys in memory. */
  if (setgroups (ngroups, groups) < 0)
    fatal ("setgroups: %m\n");
  if (setgid (groups[0]) < 0)
    fatal ("setgid: %m\n");
  if (setsid () < 0)
    warn ("setsid: %m\n");

  if (chdir (credp->unixcred->homedir) < 0)
    warn << "Could not chdir to home directory "
	 << credp->unixcred->homedir << ": " << strerror (errno) << "\n";

  /* XXX - need to reduce hard limit on file descriptors. */
}

/*
 * Handle REXD_ATTACH for this session.  RPC processing on the client's
 * transport is turned off.  The session info for argp->seqno is
 * recomputed, and the call is accepted only if the derived session ID
 * equals argp->newsessid and the sequence number has not been used
 * before.
 *
 * On success, SFS_OK is sent, and ctlconnect is installed as the
 * transport's write callback to pass the connection on to the child.
 * On failure, the derived keys are zeroed and SFS_BADLOGIN is sent.
 */
void
rexsess::attach (svccb *sbp)
{
  /* XXX - dynamic_cast is busted in egcs, hence the static_cast. */
  axprt_stream *rawxprt
    = static_cast<axprt_stream *> (sbp->getsrv ()->xprt ().get ());
  ref<axprt_stream> xs (mkref (rawxprt));

  /* XXX - This does not pipeline.  If a client sends a REXD_ATTACH RPC
   * followed by another RPC before getting the reply to the REXD_ATTACH,
   * we may read both and discard the second one, because the file
   * descriptor is handed to the child process and any extra data we
   * have buffered is dropped. */
  xhinfo::xon (xs, false);
  rexd_attach_arg *argp = sbp->template getarg<rexd_attach_arg> ();

  sfs_hash sid;
  ref<sfs_sessinfo> si = New refcounted<sfs_sessinfo>;
  seq2sessinfo (argp->seqno, &sid, si);

  /* The sequence number is only consumed if the session ID matches. */
  bool valid = sid == argp->newsessid && seqstate.check (argp->seqno);
  if (!valid) {
    bzero (si->kcs.base (), si->kcs.size ());
    bzero (si->ksc.base (), si->ksc.size ());
    warn ("newsessid mistatch\n");
    sbp->replyref (rexd_attach_res (SFS_BADLOGIN));
    return;
  }

  sbp->replyref (rexd_attach_res (SFS_OK));
  xs->setwcb (wrap (this, &rexsess::ctlconnect, destroyed, si, xs));
}

/*
 * Write callback set up by attach().  It reclaims the client's file
 * descriptor from xs.  If the session is still alive (*abort is false)
 * and the descriptor is valid, the fd is sent to the child over x and a
 * REXCTL_CONNECT call carrying si is issued.  Otherwise the fd (if any)
 * is closed.  In every case the key bytes in si are zeroed afterwards.
 *
 * When *abort is true, `this` has already been deleted, so that path
 * must not touch any member.
 */
void
rexsess::ctlconnect (ref<bool> abort, ref<sfs_sessinfo> si,
		     ref<axprt_stream> xs)
{
  int fd = xs->reclaim ();
  bool usable = !*abort && fd >= 0;

  if (usable) {
    x->sendfd (fd);
    /* XXX - c->call may leave un-bzeroed copies of session key around */
    c->call (REXCTL_CONNECT, si, NULL, aclnt_cb_null);
  }
  else if (fd >= 0)
    close (fd);

  bzero (si->kcs.base (), si->kcs.size ());
  bzero (si->ksc.base (), si->ksc.size ());
}

/*
 * State for one client connection on an encrypted transport.  rexs
 * serves rexd_prog_1 on that transport and routes every call to
 * dispatch().  doconnect() is the hook the sfsserv base class (defined
 * elsewhere) uses to obtain this server's info and private key.
 */
struct rexclient : public sfsserv {
  ptr<asrv> rexs;

  rexclient (ref<axprt_crypt> x)
    : sfsserv (x),
      rexs (asrv::alloc (x, rexd_prog_1,
			 wrap (this, &rexclient::dispatch)))
  {
  }

  /* Report the global servinfo and hand back the host key. */
  ptr<sfspriv> doconnect (const sfs_connectarg *ci, sfs_servinfo *si)
  {
    *si = servinfo;
    return ::sk;
  }

  void dispatch (svccb *sbp);
};

/*
 * RPC dispatcher for one client connection.  A NULL sbp means the
 * connection is gone, and this rexclient deletes itself.
 *
 *   REXD_NULL   - empty reply.
 *   REXD_SPAWN  - the call's authno must name an SFS_UNIXCRED entry in
 *                 credtab (else AUTH_BADCRED); a new rexsess is created
 *                 and its spawn result is returned.
 *   REXD_ATTACH - if a session with argp->sessid exists, the call goes to
 *                 its attach() and this rexclient deletes itself (the
 *                 session takes over the connection); otherwise the reply
 *                 is SFS_BADLOGIN.
 *   anything else - rejected with PROC_UNAVAIL.
 */
void
rexclient::dispatch (svccb *sbp)
{
  if (sbp == NULL) {
    delete this;
    return;
  }

  switch (sbp->proc ()) {
  case REXD_NULL:
    sbp->reply (NULL);
    return;

  case REXD_SPAWN:
    {
      u_int32_t authno = sbp->getaui ();
      bool haveunixcred = authno < credtab.size ()
	&& credtab[authno].type == SFS_UNIXCRED;
      if (!haveunixcred) {
	sbp->reject (AUTH_BADCRED);
	return;
      }
      rexd_spawn_res res;
      vNew rexsess (&credtab[authno],
		    sbp->template getarg<rexd_spawn_arg> (), &res);
      sbp->replyref (res);
      return;
    }

  case REXD_ATTACH:
    {
      rexd_attach_arg *argp = sbp->template getarg<rexd_attach_arg> ();
      rexsess *sp = sesstab[argp->sessid];
      if (!sp) {
	sbp->replyref (rexd_attach_res (SFS_BADLOGIN));
	return;
      }
      sp->attach (sbp);
      delete this;
      return;
    }

  default:
    sbp->reject (PROC_UNAVAIL);
    return;
  }
}

/*
 * Connection callback registered with sfssd_slave() in main().  A null
 * transport is treated as loss of the link to sfssd and is fatal.
 * Otherwise a rexclient is created for it; the rexclient deletes itself
 * from its own dispatch().
 */
void
client_accept (ptr<axprt_crypt> x)
{
  if (x) {
    vNew rexclient (x);
    return;
  }
  fatal ("EOF from sfssd\n");
}

/*
 * Load the host private key into the global sk.  path is resolved via
 * sfsconst_etcfile() and defaults to "sfs_host_key".  Failing to locate,
 * read, or decode the key is fatal.
 */
static void
loadkey (const char *path)
{
  const char *name = path ? path : "sfs_host_key";

  str keyfile = sfsconst_etcfile (name);
  if (!keyfile)
    fatal << name << ": " << strerror (errno) << "\n";

  str key = file2wstr (keyfile);
  if (!key)
    fatal << keyfile << ": " << strerror (errno) << "\n";

  sk = sfscrypt.alloc_priv (key, SFS_DECRYPT);
  if (!sk)
    fatal << "could not decode " << keyfile << "\n";
}

/*
 * Print a usage summary and exit via fatal.  The message lists every
 * option accepted by main()'s getopt string "k:h:"; previously the
 * -h option was missing from it.
 */
static void
usage ()
{
  fatal << "usage: " << progname << " [-k keyfile] [-h hostname]\n";
}

/*
 * Spawn the ptyd helper found via fix_exec_path().  A missing binary or
 * a failed spawn is only warned about; rexd keeps running without ptyd.
 */
static void
startptyd ()
{
  str ptydpath = fix_exec_path ("ptyd");
  if (!ptydpath) {
    warn ("could not find ptyd (should be in %s)\n", execdir.cstr ());
    return;
  }

  /* argv is an array of (non-const) char *, and since C++11 a string
   * literal may not initialize a plain char *, so give argv[0] its own
   * writable storage instead of pointing at "ptyd" directly. */
  char ptydname[] = "ptyd";
  char *av[] = { ptydname, NULL };

  pid_t pid = spawn (ptydpath, av);
  if (pid < 0)
    warn << ptydpath << ": " << strerror (errno) << "\n";
  else
    warn << "spawning " << ptydpath << "\n";
}

/*
 * rexd entry point.  Options:
 *   -k keyfile   host key file (default "sfs_host_key", see loadkey)
 *   -h hostname  hostname to advertise (default: sfshostname())
 * Builds the advertised servinfo, locates newaid (required) and ptyd
 * (optional), ensures /tmp/.X11-unix exists, then hands control to the
 * event loop, accepting connections from sfssd.
 */
int
main (int argc, char **argv)
{
  const char *keyfile = NULL;
  setprogname (argv[0]);
  sfsconst_init ();

  servinfo.set_sivers (7);
  servinfo.cr7->host.hostname = "";
  servinfo.cr7->host.port = 0;
  servinfo.cr7->release = 7;

  int ch;
  while ((ch = getopt (argc, argv, "k:h:")) != -1)
    switch (ch) {
    case 'k':
      keyfile = optarg;
      break;
    case 'h':
      servinfo.cr7->host.hostname = optarg;
      break;
    default:
      usage ();
    }
  if (optind < argc)
    usage ();

  warn ("version %s, pid %d\n", VERSION, getpid ());
  loadkey (keyfile);

  servinfo.cr7->host.type = SFS_HOSTINFO;
  if (servinfo.cr7->host.hostname == "")
    servinfo.cr7->host.hostname = sfshostname ();
  if (!sk->export_pubkey (&servinfo.cr7->host.pubkey))
    fatal ("could not get pubkey\n");
  servinfo.cr7->prog = REXD_PROG;
  servinfo.cr7->vers = REXD_VERS;

  siw = sfs_servinfo_w::alloc (servinfo);

  warn << "serving " << siw->mkpath () << "\n";
  newaid = fix_exec_path ("newaid");
  if (!newaid)
    fatal ("could not find newaid (should be in %s)\n", execdir.cstr ());

  startptyd ();

  /* Create the X11 socket directory world-writable and sticky (01777);
   * the umask is cleared so the mode is applied exactly.  An existing
   * directory (EEXIST) is fine; any other failure is reported instead of
   * being silently ignored.
   * NOTE(review): an existing /tmp/.X11-unix is accepted as-is; its
   * owner and mode are not verified. */
  mode_t m = umask (0);
  if (mkdir ("/tmp/.X11-unix", 01777) < 0 && errno != EEXIST)
    warn ("mkdir /tmp/.X11-unix: %m\n");
  umask (m);

  sfssd_slave (wrap (client_accept));
  amain ();
}

