[178] in 6.033-lab

home help back first fref pref prev next nref lref last post

design proposal and tcpproxy

daemon@ATHENA.MIT.EDU (Benjie Chen)
Wed Mar 1 09:29:42 2000

From: Benjie Chen <benjie@cag.lcs.mit.edu>
Message-Id: <200003011429.JAA25529@amsterdam.lcs.mit.edu>
To: 6.033-lab@MIT.EDU
Date: Wed, 1 Mar 2000 09:29:30 -0500 (EST)
Mime-Version: 1.0
Content-Type: text/plain; charset=us-ascii
Content-Transfer-Encoding: 7bit

Hi gang

A lot of you have not sent me your design proposals yet. If you don't send yours, I
won't be able to give you comments on it.

Attached is the solution to the TCP proxy. If you find bugs, let me know. This code
uses the async library I sent out earlier.

Benjie

-------------- proxy.c -----------------

#include <stdio.h>
#include <assert.h>
#include <ctype.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <netdb.h>
#include <signal.h>
#include <netinet/in.h>
#include <sys/types.h>
#include <sys/socket.h>
#include "../async/async.h"

#define TIMEOUT		(10*1000)	/* timeout passed to cb_timer_add
					   (presumably ms, i.e. 10 s -- confirm
					   against async.h) */
#define MAX_CONS	((NCON_MAX)/2)	/* maximum number of proxied
					   connections: each one uses two
					   descriptors (client + server) */
#define WBUF_SIZE	2048		/* size of each write buffer; must divide
					   UINT_MAX+1 (a power of two) so ring
					   indices stay consistent when the
					   free-running counters wrap */
#define WBUF_LIM	64		/* if a write buffer has more than this
					   much room, schedule the partner's
					   read cb */
/* Debug printing: type-checked like printf but never executed.  The
 * do/while wrapper makes "if (x) dprintf(...); else ..." bind the else to
 * the caller's if.  <stdio.h> may already define a dprintf macro, so drop
 * it first. */
#undef dprintf
#define dprintf(...)	do { if (0) printf(__VA_ARGS__); } while (0)

/* Proxy-wide state. */
static int listening = 0;		/* nonzero while the accept cb is registered */
static int ncons = 0;			/* live client/server connection pairs */
static struct sockaddr_in srvaddr;	/* where accepted connections are forwarded */

/*
 * One endpoint of a proxied connection (client side or server side).
 * Bytes waiting to be written to this endpoint's socket sit in wbuf, a
 * ring of WBUF_SIZE bytes; wbuf_start and wbuf_end are free-running
 * counters whose difference is the number of queued bytes and which,
 * taken modulo WBUF_SIZE, index into wbuf.
 */
struct fcon {
  int fd;			/* socket descriptor for this endpoint */
  int timer_id;			/* armed timeout timer, or -1 if none */
  int read_done;		/* set once reading from fd has stopped */
  int write_done;		/* set once writing to fd has stopped */
  struct fcon *partner;		/* the other endpoint of this connection */
  unsigned int wbuf_start;	/* counter: next queued byte to write */
  unsigned int wbuf_end;	/* counter: next free slot in the ring */
  char wbuf[WBUF_SIZE];		/* ring of bytes pending for fd */
};

/* Callbacks handed to the async library; defined further down. */
static void conn_timedout(void *arg);
static void read_ready(void *arg);
static void write_ready(void *arg);

/*
 * Destroy one endpoint.  Cancels its timer if one is armed; if it has a
 * descriptor, drops its write (1) and read (0) callbacks and closes it;
 * finally frees the struct.  FC must not be used afterwards.
 */
static inline void
fcon_free (struct fcon *fc)
{
  int tid = fc->timer_id;
  int sock = fc->fd;

  if (tid >= 0)
    cb_timer_free (tid);

  if (sock >= 0) {
    cb_io_free (sock, 1);	/* write callback */
    cb_io_free (sock, 0);	/* read callback */
    close (sock);
  }

  xfree (fc);
}

/*
 * Disarm FC's timeout timer, if any, and mark it as having none (-1).
 */
static inline void
clear_timer(struct fcon *fc)
{
  int tid = fc->timer_id;

  fc->timer_id = -1;
  if (tid >= 0)
    cb_timer_free(tid);
}

/*
 * Arm FC's timeout (TIMEOUT, firing conn_timedout(fc)) unless a timer is
 * already armed, in which case the existing deadline is kept.
 */
static inline void
set_timer(struct fcon *fc)
{
  if (fc->timer_id >= 0)
    return;	/* already armed */
  fc->timer_id = cb_timer_add(TIMEOUT, conn_timedout, fc);
}

/*
 * Timer callback armed by set_timer(); ARG is the endpoint whose timer
 * expired.  Logs the timeout, frees both this endpoint and its partner,
 * and drops the connection count.
 *
 * NOTE(review): fcon_free() passes this endpoint's timer_id -- the timer
 * that is firing right now -- to cb_timer_free(); confirm the async
 * library tolerates freeing a timer from inside its own callback.
 */
static void
conn_timedout(void *arg)
{
  struct fcon *expired = arg;
  struct fcon *other = expired->partner;

  fprintf(stderr,"%d: connection timedout\n", expired->fd);
  fcon_free(other);
  fcon_free(expired);
  ncons--;
}

/*
 * Write callback for FC (arg): drain FC's write buffer into FC->fd.
 *
 * wbuf is a ring of WBUF_SIZE bytes filled by read_ready() on the partner.
 * wbuf_start and wbuf_end are free-running unsigned byte counters: their
 * difference is the number of queued bytes, and each taken modulo
 * WBUF_SIZE is an index into wbuf.
 *
 * - On a write error, the partner's read side is shut down (its data can
 *   no longer be delivered) and the queued bytes are discarded.
 * - Once the buffer is empty the write callback is cancelled; if the
 *   partner has also stopped reading, FC's write side is shut down, and
 *   if both directions are then finished, both endpoints are freed.
 * - While data remains queued, the timeout timer is (re)armed.
 * - Finally, if the buffer has more than WBUF_LIM bytes free, the
 *   partner's read callback is (re)registered.
 */
static void
write_ready(void *arg)
{
  struct fcon *fc = (struct fcon *)arg;
  int n = 0;
  unsigned end;
  
  dprintf("%d: write ready\n", fc->fd);
  
  assert(!fc->write_done);

  /* clear any pending timer; it is re-armed below if data is still queued */
  clear_timer(fc);

  /* write as much as we can from the buffer */
  end = fc->wbuf_end % WBUF_SIZE;
  /* Compare with != rather than <: the counters are unsigned and wrap
   * around after UINT_MAX+1 bytes, after which wbuf_end can be numerically
   * smaller than wbuf_start while data is still queued. */
  while (fc->wbuf_start != fc->wbuf_end) {
    unsigned start = fc->wbuf_start % WBUF_SIZE;
    /* write up to the end of the queued data, or to the physical end of
     * the array if the queued data wraps around */
    unsigned top = (end > start) ? end : WBUF_SIZE;
    n = write(fc->fd, &fc->wbuf[start], top-start);
    if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR))
      break;	/* can't write more right now; wait for the next callback */
    if (n <= 0) { /* error (a zero-byte write is treated the same way) */
      /* shutdown reading on partner's descriptor since write failed */
      dprintf("%d: shutdown read\n", fc->partner->fd);
      cb_io_free(fc->partner->fd, 0);
      shutdown(fc->partner->fd, 0);
      fc->partner->read_done = 1;
      /* discard queued data.  the buffer now looks empty, which causes
       * the descriptor to get shut down for writing below */
      fc->wbuf_start = fc->wbuf_end = 0; 
      break; 
    }
    fc->wbuf_start += n;	/* wrote something */
  }

  if (fc->wbuf_start == fc->wbuf_end) {
    /* caught up, cancel write callback */
    dprintf("%d: write cb cancled\n", fc->fd);
    cb_io_free (fc->fd, 1);
    /* partner is done reading, so we are done writing, shutdown write */
    if (fc->partner->read_done) {
      dprintf("%d: shutdown write\n", fc->fd);
      shutdown(fc->fd, 1); 
      fc->write_done = 1;
      /* if we are also done reading, and partner is done writing, then
       * close the connection altogether */
      if (fc->read_done && fc->partner->write_done) {
        dprintf("%d: closing connection\n", fc->fd);
        fcon_free(fc->partner);
        fcon_free(fc);
        ncons--;
        return;
      }
    }
  } else {
    set_timer(fc);	/* data still queued: bound how long we wait */
  }

  /* if there is enough room in the wbuf, make sure the partner is reading */
  if (fc->wbuf_start+WBUF_SIZE-fc->wbuf_end > WBUF_LIM) {
    if (!fc->partner->read_done)
      cb_io_add (fc->partner->fd, 0, read_ready, fc->partner); 
  }
}

/*
 * Read callback for FC (arg): copy bytes from FC->fd into the partner's
 * write buffer, as much as that ring has room for.
 *
 * On EOF or a read error FC's read side is shut down and marked done.
 * If the partner's buffer ends up full, the read callback is dropped
 * (write_ready() re-registers it once room frees up).  The partner's
 * write callback is always (re)scheduled -- even if nothing was read --
 * so the writing side can flush and, after EOF, shut down.
 */
static void
read_ready(void *arg)
{
  struct fcon *fc = (struct fcon *)arg;
  int n = 0;
  unsigned start;

  dprintf("%d: read ready\n", fc->fd);

  assert(!fc->read_done);
  assert(!fc->partner->write_done);

  /* only read as much as the partner's wbuf can accommodate.  wbuf_start
   * does not move while we are in here, so its ring index is fixed. */
  start = fc->partner->wbuf_start % WBUF_SIZE;
  while (fc->partner->wbuf_end - fc->partner->wbuf_start < WBUF_SIZE) {
    unsigned end = fc->partner->wbuf_end % WBUF_SIZE;
    /* stop at the oldest unwritten byte if it lies ahead of us in the
     * array, otherwise at the physical end of the array */
    unsigned top = start > end ? start : WBUF_SIZE;
    n = read(fc->fd, &fc->partner->wbuf[end], top-end);
    if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR))
      break;	/* nothing more to read right now */
    if (n <= 0) { /* EOF or error */
      /* close descriptor for reading and mark it as such */
      dprintf("%d: shutdown read\n", fc->fd);
      cb_io_free (fc->fd, 0);
      shutdown(fc->fd, 0);
      fc->read_done = 1;
      break;
    }
    /* read something: if the partner's write timer is not set, set it */
    set_timer(fc->partner);
    fc->partner->wbuf_end += n;
  } 
  
  if (fc->partner->wbuf_end - fc->partner->wbuf_start == WBUF_SIZE)
    /* don't read any more for now: the partner's writes are backlogged */
    cb_io_free (fc->fd, 0);

  /* even if we read nothing, we still need to schedule write_ready so we can
   * close everything on the writing end if wbuf is empty */ 
  cb_io_add (fc->partner->fd, 1, write_ready, fc->partner); 
}

/*
 * Write-ready callback on the server socket while its connect is in
 * progress; ARG is the server-side fcon.
 *
 * Cancels the connect timer and checks the connect's outcome.  On failure
 * both endpoints are freed; on success the write watch is dropped and both
 * sockets are watched for readability (read_ready), which starts copying
 * data in both directions.
 */
static void
srv_conn_ready(void *arg)
{
  struct fcon *srv_fc = (struct fcon *)arg;
  struct sockaddr_in peer;
  socklen_t peerlen = sizeof(peer);	/* value-result: must be initialized */
  int err = 0;
  socklen_t errlen = sizeof(err);	/* value-result: must be initialized */

  clear_timer(srv_fc);

  /* SO_ERROR holds the result of the asynchronous connect.  Some systems
   * report the pending error as a getsockopt() failure instead, so fall
   * back to errno.  getpeername() failing (e.g. ENOTCONN) also means the
   * connect did not succeed. */
  if (getsockopt(srv_fc->fd, SOL_SOCKET, SO_ERROR, &err, &errlen) < 0)
    err = errno;
  if (err == 0 &&
      getpeername(srv_fc->fd, (struct sockaddr *)&peer, &peerlen) < 0)
    err = errno;

  if (err) {
    /* connect failed: tear down both halves */
    fcon_free(srv_fc->partner);
    fcon_free(srv_fc);
    ncons--;
    return;
  }
  
  dprintf("%d: connection established to server\n", srv_fc->fd);

  /* connection succeeded - now we just react to both sockets */
  cb_io_free (srv_fc->fd, 1);
  cb_io_add (srv_fc->fd, 0, read_ready, srv_fc); 
  cb_io_add (srv_fc->partner->fd, 0, read_ready, srv_fc->partner); 
}

/*
 * Accept callback for the listening socket; ARG is the listener's fcon.
 *
 * Accepts one client, opens a socket and starts connecting it to srvaddr
 * (EINPROGRESS is accepted, so the connect may complete later), and pairs
 * the two descriptors as proxy partners.  The server socket is watched for
 * writability (srv_conn_ready) to learn when the connect finishes, and a
 * timer (set_timer) bounds how long that may take.  When ncons reaches
 * MAX_CONS the accept callback is removed and `listening' cleared.
 */
static void 
local_accept(void *arg)
{
  struct fcon *fc = (struct fcon *)arg;
  struct fcon *clt_fc;
  struct fcon *srv_fc;
  struct sockaddr_in from;
  struct sockaddr_in to;
  socklen_t fromlen = sizeof(from);	/* accept() takes a socklen_t * */
  int clts, srvs;

  if ((clts = accept(fc->fd, (struct sockaddr *)&from, &fromlen)) < 0) {
    /* These are transient (e.g. the client gave up before we accepted it);
     * wait for the next callback instead of killing the whole proxy. */
    if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR ||
        errno == ECONNABORTED)
      return;
    perror("accept");
    exit(1);
  }

  dprintf("%d: client connection\n", clts);

  if ((srvs = socket (AF_INET, SOCK_STREAM, 0)) < 0) {
    /* can't get socket to server, so terminate client as well */
    close(clts);
    return;
  }
  
  ncons++;
  if (ncons == MAX_CONS) {
    /* at capacity: stop accepting new clients for now */
    cb_io_free(fc->fd, 0);
    listening = 0;
  }
  
  make_async (clts);
  make_async (srvs);

  /* bzero leaves timer_id == 0, which set_timer/clear_timer/fcon_free would
   * treat as an armed timer (so the timeout would never be armed, and timer
   * id 0 would get freed); -1 means "no timer". */
  clt_fc = xmalloc (sizeof (*clt_fc));
  bzero (clt_fc, sizeof (*clt_fc));
  clt_fc->fd = clts;
  clt_fc->timer_id = -1;
  srv_fc = xmalloc (sizeof (*srv_fc));
  bzero (srv_fc, sizeof (*srv_fc));
  srv_fc->fd = srvs;
  srv_fc->timer_id = -1;
  clt_fc->partner = srv_fc;
  srv_fc->partner = clt_fc;

  /* now try to make a connection to server */
  to = srvaddr;
  if (connect(srvs, (struct sockaddr *)&to, sizeof(to)) < 0 &&
      errno != EINPROGRESS) {
    /* can't connect to server, so terminate client as well */
    fcon_free(srv_fc);
    fcon_free(clt_fc);
    ncons--;
    return;
  }

  /* writability on srvs signals that the connect has finished */
  cb_io_add (srvs, 1, srv_conn_ready, srv_fc);
  set_timer(srv_fc);
}

/*
 * Parse S as a decimal TCP port.  Returns the port (1..65535), or -1 if S
 * is empty, has trailing characters, or is out of range.
 */
static int
parse_port (const char *s)
{
  char *end;
  long v;

  errno = 0;
  v = strtol(s, &end, 10);
  if (end == s || *end != '\0' || errno == ERANGE || v < 1 || v > 65535)
    return -1;
  return (int)v;
}

/*
 * usage: proxy dst_host dst_port local_port
 *
 * Listens on local_port on all interfaces and forwards each accepted TCP
 * connection to dst_host:dst_port, driven by the async library's callback
 * loop (cb_check).  Exits with status 1 on bad arguments or setup failure.
 */
int
main (int argc, char **argv)
{
  char *srv_name;
  struct hostent *h;
  int srv_port, local_port;
  struct fcon *fc;
  struct sockaddr_in localaddr;
  int s, v;

  /* parse command line */
  if (argc != 4) {
    fprintf(stderr,"usage: %s dst_host dst_port local_port\n", argv[0]);
    exit(1);
  }
  srv_port = parse_port(argv[2]);
  local_port = parse_port(argv[3]);
  if (srv_port < 0 || local_port < 0) {
    fprintf(stderr,"usage: %s dst_host dst_port local_port\n", argv[0]);
    exit(1);
  }
  srv_name = argv[1];

  /* lookup server name */
  h = gethostbyname (srv_name);
  if (!h) {
    fprintf (stderr, "%s: hostname lookup failed\n", srv_name);
    exit(1);
  }

  /* create socket */
  if ((s = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
    perror("socket");
    exit(1);
  }
 
  /* bind to local port */
  v = 1; setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (char*)&v, sizeof(v));
  bzero (&localaddr, sizeof(localaddr));
  localaddr.sin_family = AF_INET;
  localaddr.sin_addr.s_addr = htonl(INADDR_ANY); /* anyone can connect */
  localaddr.sin_port = htons(local_port);

  if (bind(s, (struct sockaddr *)&localaddr, sizeof(localaddr))< 0) {
    perror("bind");
    exit(1);
  }

  /* allow a backlog of 5 pending connection requests */
  if (listen (s, 5) < 0) {
    perror("listen");
    exit(1);
  }
 
  /* make descriptor asynchronous */
  make_async (s);

  /* create callback state for accept; -1 means no timer */
  fc = xmalloc (sizeof (*fc));
  bzero (fc, sizeof (*fc));
  fc->fd = s;
  fc->timer_id = -1;
 
  /* save destination host information */
  bzero (&srvaddr, sizeof(srvaddr));
  srvaddr.sin_family = AF_INET;
  srvaddr.sin_port = htons (srv_port);
  srvaddr.sin_addr = *(struct in_addr *) h->h_addr;
  
  /* writing to an unconnected socket will cause a process to receive a
   * SIGPIPE signal. we don't want to die if this happens, so we ignore
   * SIGPIPE.  */
  signal (SIGPIPE, SIG_IGN);

  /* main loop: (re)register the accept callback whenever we are below the
   * connection limit, then dispatch pending callbacks */
  while (1) {
    if (!listening && ncons < MAX_CONS) { 
      cb_io_add (fc->fd, 0, local_accept, fc);
      listening = 1;
    }
    cb_check();
  }
  
  exit (0);	/* not reached */
}


home help back first fref pref prev next nref lref last post