Back to index

courier  0.68.2
starttls.c
Go to the documentation of this file.
00001 /*
00002 ** Copyright 2000-2008 Double Precision, Inc.
00003 ** See COPYING for distribution information.
00004 */
00005 #include      "config.h"
00006 #include      "argparse.h"
00007 #include      "spipe.h"
00008 
00009 #include      "libcouriertls.h"
00010 #include      "tlscache.h"
00011 #include      "rfc1035/rfc1035.h"
00012 #include      "soxwrap/soxwrap.h"
00013 #ifdef  getc
00014 #undef  getc
00015 #endif
00016 #include      <stdio.h>
00017 #include      <string.h>
00018 #include      <stdlib.h>
00019 #include      <ctype.h>
00020 #include      <netdb.h>
00021 #if HAVE_DIRENT_H
00022 #include <dirent.h>
00023 #define NAMLEN(dirent) strlen((dirent)->d_name)
00024 #else
00025 #define dirent direct
00026 #define NAMLEN(dirent) (dirent)->d_namlen
00027 #if HAVE_SYS_NDIR_H
00028 #include <sys/ndir.h>
00029 #endif
00030 #if HAVE_SYS_DIR_H
00031 #include <sys/dir.h>
00032 #endif
00033 #if HAVE_NDIR_H
00034 #include <ndir.h>
00035 #endif
00036 #endif
00037 #if    HAVE_UNISTD_H
00038 #include      <unistd.h>
00039 #endif
00040 #if    HAVE_FCNTL_H
00041 #include      <fcntl.h>
00042 #endif
00043 #include      <errno.h>
00044 #if    HAVE_SYS_TYPES_H
00045 #include      <sys/types.h>
00046 #endif
00047 #if    HAVE_SYS_STAT_H
00048 #include      <sys/stat.h>
00049 #endif
00050 #include      <sys/socket.h>
00051 #include      <arpa/inet.h>
00052 
00053 #if TIME_WITH_SYS_TIME
00054 #include        <sys/time.h>
00055 #include        <time.h>
00056 #else
00057 #if HAVE_SYS_TIME_H
00058 #include        <sys/time.h>
00059 #else
00060 #include        <time.h>
00061 #endif
00062 #endif
00063 #include      <locale.h>
00064 
00065 
00066 /* Command-line options: */
/*
** Each option string below is registered in main()'s argparse() table and
** starts out as NULL; main() and its helpers treat NULL as "not given".
*/
const char *clienthost=NULL;		/* "host" option */
const char *clientport=NULL;		/* "port" option */

const char *server=NULL;		/* "server" option */
const char *localfd=NULL;		/* "localfd" option */
const char *remotefd=NULL;		/* "remotefd" option */
const char *statusfd=NULL;		/* "statusfd" option */
const char *tcpd=NULL;			/* "tcpd" option */
const char *peer_verify_domain=NULL;	/* "verify" option */
const char *fdprotocol=NULL;		/* "protocol" option */

/* Stream that diagnostics are written to; main() points it at stderr. */
static FILE *errfp;

/* Stream opened by main() on the "statusfd" descriptor, when given. */
static FILE *statusfp;

const char *printx509=NULL;		/* "printx509" option */
00081 
00082 static void ssl_errmsg(const char *errmsg, void *dummy)
00083 {
00084        fprintf(errfp, "%s\n", errmsg);
00085 }
00086 
00087 static void nonsslerror(const char *pfix)
00088 {
00089        fprintf(errfp, "%s: %s\n", pfix, strerror(errno));
00090 }
00091 
00092 void docopy(ssl_handle ssl, int sslfd, int stdinfd, int stdoutfd)
00093 {
00094        struct tls_transfer_info transfer_info;
00095 
00096        char from_ssl_buf[BUFSIZ], to_ssl_buf[BUFSIZ];
00097        char *fromptr;
00098        int rc;
00099 
00100        fd_set fdr, fdw;
00101        int    maxfd=sslfd;
00102 
00103        if (fcntl(stdinfd, F_SETFL, O_NONBLOCK)
00104            || fcntl(stdoutfd, F_SETFL, O_NONBLOCK)
00105            )
00106        {
00107               nonsslerror("fcntl");
00108               return;
00109        }
00110 
00111        if (maxfd < stdinfd) maxfd=stdinfd;
00112        if (maxfd < stdoutfd)       maxfd=stdoutfd;
00113 
00114        tls_transfer_init(&transfer_info);
00115 
00116        transfer_info.readptr=fromptr=from_ssl_buf;
00117 
00118        for (;;)
00119        {
00120               if (transfer_info.readptr == fromptr)
00121               {
00122                      transfer_info.readptr=fromptr=from_ssl_buf;
00123                      transfer_info.readleft=sizeof(from_ssl_buf);
00124               }
00125               else
00126                      transfer_info.readleft=0;
00127 
00128               FD_ZERO(&fdr);
00129               FD_ZERO(&fdw);
00130 
00131               rc=tls_transfer(&transfer_info, ssl, sslfd, &fdr, &fdw);
00132 
00133               if (rc == 0)
00134                      continue;
00135               if (rc < 0)
00136                      break;
00137 
00138               if (!tls_inprogress(&transfer_info))
00139               {
00140                      if (transfer_info.readptr > fromptr)
00141                             FD_SET(stdoutfd, &fdw);
00142 
00143                      if (transfer_info.writeleft == 0)
00144                             FD_SET(stdinfd, &fdr);
00145               }
00146 
00147               if (select(maxfd+1, &fdr, &fdw, 0, 0) <= 0)
00148               {
00149                      if (errno != EINTR)
00150                      {
00151                             nonsslerror("select");
00152                             break;
00153                      }
00154                      continue;
00155               }
00156 
00157               if (FD_ISSET(stdoutfd, &fdw) &&
00158                   transfer_info.readptr > fromptr)
00159               {
00160                      rc=write(stdoutfd, fromptr,
00161                              transfer_info.readptr - fromptr);
00162 
00163                      if (rc <= 0)
00164                             break;
00165 
00166                      fromptr += rc;
00167               }
00168 
00169               if (FD_ISSET(stdinfd, &fdr) && transfer_info.writeleft == 0)
00170               {
00171                      rc=read(stdinfd, to_ssl_buf, sizeof(to_ssl_buf));
00172                      if (rc <= 0)
00173                             break;
00174 
00175                      transfer_info.writeptr=to_ssl_buf;
00176                      transfer_info.writeleft=rc;
00177               }
00178        }
00179 
00180        tls_closing(&transfer_info);
00181 
00182        for (;;)
00183        {
00184               FD_ZERO(&fdr);
00185               FD_ZERO(&fdw);
00186 
00187               if (tls_transfer(&transfer_info, ssl, sslfd, &fdr, &fdw) < 0)
00188                      break;
00189 
00190               if (select(maxfd+1, &fdr, &fdw, 0, 0) <= 0)
00191               {
00192                      if (errno != EINTR)
00193                      {
00194                             nonsslerror("select");
00195                             break;
00196                      }
00197                      continue;
00198               }
00199        }
00200 }
00201 
/* State carried between calls to dump_to_fp(). */
struct dump_capture_subject {
	char line[1024];	/* Line being collected, without its newline */
	int line_size;		/* Number of characters collected in line[] */

	int set_subject;	/* Nonzero: export subject fields with setenv() */
	int seen_subject;	/* Nonzero once a "Subject:" line has been seen */
	int in_subject;		/* Nonzero while inside the first subject section */
	FILE *fp;		/* When not NULL, the dump is copied here as is */
};

/*
** Output callback for the connection dump.
**
** p    - text to process
** cnt  - number of bytes at p, or a negative value if p is a
**        null-terminated string
** arg  - pointer to the struct dump_capture_subject holding the state
**
** The text is copied to dcs->fp, when one is set, and is also split into
** lines; a line may arrive in pieces over several calls.  Lines longer
** than line[] can hold are truncated.  Lines that begin with a space and
** follow the first "Subject:" line are taken as "name=value"; when
** set_subject is nonzero each becomes the environment variable
** TLS_SUBJECT_<NAME>, with the name converted to upper case.  The section
** ends at the first line that does not begin with a space, and any later
** "Subject:" sections are ignored.
*/
static void dump_to_fp(const char *p, int cnt, void *arg)
{
	struct dump_capture_subject *dcs=(struct dump_capture_subject *)arg;
	char *n, *v;
	char namebuf[64];
	int namelen;

	if (cnt < 0)
		cnt=strlen(p);

	if (dcs->fp && fwrite(p, cnt, 1, dcs->fp) != 1)
		; /* NOOP */

	while (cnt)
	{
		if (*p != '\n')
		{
			if ((size_t)dcs->line_size < sizeof(dcs->line)-1)
				dcs->line[dcs->line_size++]=*p;

			++p;
			--cnt;
			continue;
		}
		dcs->line[dcs->line_size]=0;
		++p;
		--cnt;
		dcs->line_size=0;

		if (strncmp(dcs->line, "Subject:", 8) == 0)
		{
			if (dcs->seen_subject)
				continue;

			dcs->seen_subject=1;
			dcs->in_subject=1;
			continue;
		}

		if (!dcs->in_subject)
			continue;

		if (dcs->line[0] != ' ')
		{
			dcs->in_subject=0;
			continue;
		}

		for (n=dcs->line; *n; n++)
			if (*n != ' ')
				break;

		/* Upper-case the name and cut the line in two at the '=' */
		for (v=n; *v; v++)
		{
			/* <ctype.h> requires an unsigned char value */
			*v=toupper((int)(unsigned char)*v);
			if (*v == '=')
			{
				*v++=0;
				break;
			}
		}

		/*
		** snprintf() returns the length the complete string would
		** have had, which can exceed the buffer, so the result must
		** not be used as an index.  A name that does not fit is
		** skipped instead of being exported under a cut-off name.
		*/
		namelen=snprintf(namebuf, sizeof(namebuf),
				 "TLS_SUBJECT_%s", n);

		if (namelen < 0 || (size_t)namelen >= sizeof(namebuf))
			continue;

		if (dcs->set_subject)
			setenv(namebuf, v, 1);
	}
}
00280 
00281 static int verify_connection(ssl_handle ssl, void *dummy)
00282 {
00283        FILE   *printx509_fp=NULL;
00284        int    printx509_fd=0;
00285        char   *buf;
00286 
00287        struct dump_capture_subject dcs;
00288 
00289        memset(&dcs, 0, sizeof(dcs));
00290 
00291        if (printx509)
00292        {
00293               printx509_fd=atoi(printx509);
00294 
00295               printx509_fp=fdopen(printx509_fd, "w");
00296                 if (!printx509_fp)
00297                         nonsslerror("fdopen");
00298        }
00299 
00300        dcs.fp=printx509_fp;
00301 
00302        dcs.set_subject=0;
00303 
00304        if (tls_certificate_verified(ssl))
00305               dcs.set_subject=1;
00306 
00307        tls_dump_connection_info(ssl, server ? 1:0, dump_to_fp, &dcs);
00308 
00309        if (printx509_fp)
00310        {
00311               fclose(printx509_fp);
00312        }
00313 
00314        if (statusfp)
00315        {
00316               fclose(statusfp);
00317               statusfp=NULL;
00318               errfp=stderr;
00319        }
00320 
00321        buf=tls_get_encryption_desc(ssl);
00322 
00323        setenv("TLS_CONNECTED_PROTOCOL",
00324               buf ? buf:"(unknown)", 1);
00325 
00326        if (buf)
00327               free(buf);
00328        return 1;
00329 }
00330 
00331 /* ----------------------------------------------------------------------- */
00332 
00333 static void startclient(int argn, int argc, char **argv, int fd,
00334        int *stdin_fd, int *stdout_fd)
00335 {
00336 pid_t  p;
00337 int    streampipe[2];
00338 
00339        if (localfd)
00340        {
00341               *stdin_fd= *stdout_fd= atoi(localfd);
00342               return;
00343        }
00344 
00345        if (argn >= argc)    return;              /* Interactive */
00346 
00347        if (libmail_streampipe(streampipe))
00348        {
00349               nonsslerror("libmail_streampipe");
00350               exit(1);
00351        }
00352        if ((p=fork()) == -1)
00353        {
00354               nonsslerror("fork");
00355               close(streampipe[0]);
00356               close(streampipe[1]);
00357               exit(1);
00358        }
00359        if (p == 0)
00360        {
00361        char **argvec;
00362        int n;
00363 
00364               close(fd);    /* Child process doesn't need it */
00365               dup2(streampipe[1], 0);
00366               dup2(streampipe[1], 1);
00367               close(streampipe[0]);
00368               close(streampipe[1]);
00369 
00370               argvec=malloc(sizeof(char *)*(argc-argn+1));
00371               if (!argvec)
00372               {
00373                      nonsslerror("malloc");
00374                      exit(1);
00375               }
00376               for (n=0; n<argc-argn; n++)
00377                      argvec[n]=argv[argn+n];
00378               argvec[n]=0;
00379               execvp(argvec[0], argvec);
00380               nonsslerror(argvec[0]);
00381               exit(1);
00382        }
00383        close(streampipe[1]);
00384 
00385        *stdin_fd= *stdout_fd= streampipe[0];
00386 }
00387 
00388 static int connectremote(const char *host, const char *port)
00389 {
00390 int    fd;
00391 
00392 RFC1035_ADDR addr;
00393 int    af;
00394 RFC1035_ADDR *addrs;
00395 unsigned      naddrs, n;
00396 
00397 RFC1035_NETADDR addrbuf;
00398 const struct sockaddr *saddr;
00399 int     saddrlen;
00400 int    port_num;
00401 
00402        port_num=atoi(port);
00403        if (port_num <= 0)
00404        {
00405        struct servent *servent;
00406 
00407               servent=getservbyname(port, "tcp");
00408 
00409               if (!servent)
00410               {
00411                      fprintf(errfp, "%s: invalid port.\n", port);
00412                      return (-1);
00413               }
00414               port_num=servent->s_port;
00415        }
00416        else
00417               port_num=htons(port_num);
00418 
00419        if (rfc1035_aton(host, &addr) == 0) /* An explicit IP addr */
00420        {
00421               if ((addrs=malloc(sizeof(addr))) == 0)
00422               {
00423                      nonsslerror("malloc");
00424                      return (-1);
00425               }
00426               memcpy(addrs, &addr, sizeof(addr));
00427               naddrs=1;
00428        }
00429        else
00430        {
00431               struct rfc1035_res res;
00432               int rc;
00433 
00434               rfc1035_init_resolv(&res);
00435               rc=rfc1035_a(&res, host, &addrs, &naddrs);
00436               rfc1035_destroy_resolv(&res);
00437 
00438               if (rc)
00439               {
00440                      fprintf(errfp, "%s: not found.\n", host);
00441                      return (-1);
00442               }
00443        }
00444 
00445         if ((fd=rfc1035_mksocket(SOCK_STREAM, 0, &af)) < 0)
00446         {
00447                 nonsslerror("socket");
00448                 return (-1);
00449         }
00450 
00451        for (n=0; n<naddrs; n++)
00452        {
00453               if (rfc1035_mkaddress(af, &addrbuf, addrs+n, port_num,
00454                      &saddr, &saddrlen))  continue;
00455 
00456               if (sox_connect(fd, saddr, saddrlen) == 0)
00457                      break;
00458        }
00459        free(addrs);
00460 
00461        if (n >= naddrs)
00462        {
00463               close(fd);
00464               nonsslerror("connect");
00465               return (-1);
00466        }
00467 
00468        return (fd);
00469 }
00470 
00471 static int connect_completed(ssl_handle ssl, int fd)
00472 {
00473        struct tls_transfer_info transfer_info;
00474        tls_transfer_init(&transfer_info);
00475 
00476        while (tls_connecting(ssl))
00477        {
00478               fd_set fdr, fdw;
00479 
00480               FD_ZERO(&fdr);
00481               FD_ZERO(&fdw);
00482               if (tls_transfer(&transfer_info, ssl,
00483                              fd, &fdr, &fdw) < 0)
00484                      return (0);
00485 
00486               if (!tls_connecting(ssl))
00487                      break;
00488 
00489               if (select(fd+1, &fdr, &fdw, 0, 0) <= 0)
00490               {
00491                      if (errno != EINTR)
00492                      {
00493                             nonsslerror("select");
00494                             return (0);
00495                      }
00496               }
00497        }
00498        return (1);
00499 }
00500 
00501 static int dossl(int fd, int argn, int argc, char **argv)
00502 {
00503        ssl_context ctx;
00504        ssl_handle ssl;
00505 
00506        int    stdin_fd, stdout_fd;
00507        struct tls_info info= *tls_get_default_info();
00508 
00509        info.peer_verify_domain=peer_verify_domain;
00510        info.tls_err_msg=ssl_errmsg;
00511        info.connect_callback= &verify_connection;
00512        info.app_data=NULL;
00513 
00514        ctx=tls_create(server ? 1:0, &info);
00515        if (ctx == 0) return (1);
00516 
00517        ssl=tls_connect(ctx, fd);
00518 
00519        if (!ssl)
00520        {
00521               close(fd);
00522               return (1);
00523        }
00524 
00525        if (!connect_completed(ssl, fd))
00526        {
00527               tls_disconnect(ssl, fd);
00528               close(fd);
00529               tls_destroy(ctx);
00530               return 1;
00531        }
00532 
00533        stdin_fd=0;
00534        stdout_fd=1;
00535 
00536        startclient(argn, argc, argv, fd, &stdin_fd, &stdout_fd);
00537 
00538        docopy(ssl, fd, stdin_fd, stdout_fd);
00539 
00540        tls_disconnect(ssl, fd);
00541        close(fd);
00542        tls_destroy(ctx);
00543        return (0);
00544 }
00545 
/* Read buffer used for the plain-text dialogue before TLS starts. */
struct protoreadbuf {
	char buffer[512];	/* Bytes from the most recent read() */
	char *bufptr;		/* Next unread byte in buffer[] */
	int bufleft;		/* Number of unread bytes at bufptr */

	char line[256];		/* Line assembled by prb_getline() */
} ;

/* Mark the buffer as empty, so that the first fetch reads from the fd. */
#define PRB_INIT(p) ( (p)->bufptr=NULL, (p)->bufleft=0 )
00555 
00556 static char protoread(int fd, struct protoreadbuf *prb)
00557 {
00558        fd_set fds;
00559        struct timeval tv;
00560 
00561        FD_ZERO(&fds);
00562        FD_SET(fd, &fds);
00563 
00564        tv.tv_sec=60;
00565        tv.tv_usec=0;
00566 
00567        if (select(fd+1, &fds, NULL, NULL, &tv) <= 0)
00568        {
00569               nonsslerror("select");
00570               exit(1);
00571        }
00572 
00573        if ( (prb->bufleft=read(fd, prb->buffer, sizeof(prb->buffer))) <= 0)
00574        {
00575               errno=ECONNRESET;
00576               nonsslerror("read");
00577               exit(1);
00578        }
00579 
00580        prb->bufptr= prb->buffer;
00581 
00582        --prb->bufleft;
00583        return (*prb->bufptr++);
00584 }
00585 
00586 #define PRB_GETCH(fd,prb) ( (prb)->bufleft-- > 0 ? *(prb)->bufptr++:\
00587                             protoread( (fd), (prb)))
00588 
00589 static const char *prb_getline(int fd, struct protoreadbuf *prb)
00590 {
00591        int i=0;
00592        char c;
00593 
00594        while ((c=PRB_GETCH(fd, prb)) != '\n')
00595        {
00596               if ( i < sizeof (prb->line)-1)
00597                      prb->line[i++]=c;
00598        }
00599        prb->line[i]=0;
00600        return (prb->line);
00601 }
00602 
/*
** Echo a string to standard output, then write all of it to fd,
** repeating write() as needed.  A failed write is reported and ends the
** program.  The prb argument is not used.
*/
static void prb_write(int fd, struct protoreadbuf *prb, const char *p)
{
	size_t remaining;

	(void)prb;

	fputs(p, stdout);

	for (remaining=strlen(p); remaining > 0; )
	{
		int written=write(fd, p, remaining);

		if (written <= 0)
		{
			nonsslerror("write");
			exit(1);
		}

		p += written;
		remaining -= written;
	}
}
00618 
/*
** Classify one line of an IMAP server reply.
**
** Only lines that start with "*", or with the tag "x" followed by
** whitespace, are examined; anything else returns 0.  After that prefix
** and any whitespace, a status of BAD, BYE or NO (in any letter case)
** ends the program with exit code 1.  Returns nonzero if the status is
** OK, and 0 otherwise.
*/
static int goodimap(const char *p)
{
	int tagged=(p[0] == 'x' && p[1] != 0 &&
		    isspace((int)(unsigned char)p[1]));

	if (!tagged && *p != '*')
		return (0);

	/* Step over the one-character prefix, then over the whitespace */
	++p;
	while (*p && isspace((int)(unsigned char)*p))
		++p;

	if (strncasecmp(p, "BAD", 3) == 0 ||
	    strncasecmp(p, "BYE", 3) == 0 ||
	    strncasecmp(p, "NO", 2) == 0)
		exit(1);

	return (strncasecmp(p, "OK", 2) == 0);
}
00648 
00649 static void imap_proto(int fd)
00650 {
00651        struct protoreadbuf prb;
00652        const char *p;
00653 
00654        PRB_INIT(&prb);
00655 
00656        do
00657        {
00658               p=prb_getline(fd, &prb);
00659               printf("%s\n", p);
00660 
00661        } while (!goodimap(p));
00662 
00663        prb_write(fd, &prb, "x STARTTLS\r\n");
00664 
00665        do
00666        {
00667               p=prb_getline(fd, &prb);
00668               printf("%s\n", p);
00669        } while (!goodimap(p));
00670 }
00671 
00672 static void pop3_proto(int fd)
00673 {
00674        struct protoreadbuf prb;
00675        const char *p;
00676 
00677        PRB_INIT(&prb);
00678 
00679        p=prb_getline(fd, &prb);
00680        printf("%s\n", p);
00681 
00682        prb_write(fd, &prb, "STLS\r\n");
00683 
00684        p=prb_getline(fd, &prb);
00685        printf("%s\n", p);
00686 }
00687 
00688 static void smtp_proto(int fd)
00689 {
00690        struct protoreadbuf prb;
00691        const char *p;
00692 
00693        char hostname[1024];
00694 
00695        PRB_INIT(&prb);
00696 
00697        do
00698        {
00699               p=prb_getline(fd, &prb);
00700               printf("%s\n", p);
00701        } while ( ! ( isdigit((int)(unsigned char)p[0]) && 
00702                     isdigit((int)(unsigned char)p[1]) &&
00703                     isdigit((int)(unsigned char)p[2]) &&
00704                     (p[3] == 0 || isspace((int)(unsigned char)p[3]))));
00705        if (strchr("123", *p) == 0)
00706               exit(1);
00707 
00708        hostname[sizeof(hostname)-1]=0;
00709        if (gethostname(hostname, sizeof(hostname)-1) < 0)
00710               strcpy(hostname, "localhost");
00711 
00712        prb_write(fd, &prb, "EHLO ");
00713        prb_write(fd, &prb, hostname);
00714        prb_write(fd, &prb, "\r\n");
00715        do
00716        {
00717               p=prb_getline(fd, &prb);
00718               printf("%s\n", p);
00719        } while ( ! ( isdigit((int)(unsigned char)p[0]) && 
00720                     isdigit((int)(unsigned char)p[1]) &&
00721                     isdigit((int)(unsigned char)p[2]) &&
00722                     (p[3] == 0 || isspace((int)(unsigned char)p[3]))));
00723        if (strchr("123", *p) == 0)
00724               exit(1);
00725 
00726        prb_write(fd, &prb, "STARTTLS\r\n");
00727 
00728        do
00729        {
00730               p=prb_getline(fd, &prb);
00731               printf("%s\n", p);
00732        } while ( ! ( isdigit((int)(unsigned char)p[0]) && 
00733                     isdigit((int)(unsigned char)p[1]) &&
00734                     isdigit((int)(unsigned char)p[2]) &&
00735                     (p[3] == 0 || isspace((int)(unsigned char)p[3]))));
00736        if (strchr("123", *p) == 0)
00737               exit(1);
00738 
00739 }
00740 
00741 int main(int argc, char **argv)
00742 {
00743 int    argn;
00744 int    fd;
00745 static struct args arginfo[] = {
00746        { "host", &clienthost },
00747        { "localfd", &localfd},
00748        { "port", &clientport },
00749        { "printx509", &printx509},
00750        { "remotefd", &remotefd},
00751        { "server", &server},
00752        { "tcpd", &tcpd},
00753        { "verify", &peer_verify_domain},
00754        { "statusfd", &statusfd},
00755        { "protocol", &fdprotocol},
00756        {0}};
00757 void (*protocol_func)(int)=0;
00758 
00759        setlocale(LC_ALL, "");
00760        errfp=stderr;
00761 
00762        argn=argparse(argc, argv, arginfo);
00763 
00764        if (statusfd)
00765               statusfp=fdopen(atoi(statusfd), "w");
00766 
00767        if (statusfp)
00768               errfp=statusfp;
00769 
00770        if (fdprotocol)
00771        {
00772               if (strcmp(fdprotocol, "smtp") == 0)
00773                      protocol_func= &smtp_proto;
00774               else if (strcmp(fdprotocol, "imap") == 0)
00775                      protocol_func= &imap_proto;
00776               else if (strcmp(fdprotocol, "pop3") == 0)
00777                      protocol_func= &pop3_proto;
00778               else
00779               {
00780                      fprintf(stderr, "--protocol=%s - unknown protocol.\n",
00781                             fdprotocol);
00782                      exit(1);
00783               }
00784        }
00785 
00786        if (tcpd)
00787        {
00788               dup2(2, 1);
00789               fd=0;
00790        }
00791        else if (remotefd)
00792               fd=atoi(remotefd);
00793        else if (clienthost && clientport)
00794               fd=connectremote(clienthost, clientport);
00795        else
00796        {
00797               fprintf(errfp, "%s: specify remote location.\n",
00798                      argv[0]);
00799               return (1);
00800        }
00801 
00802        if (fd < 0)   return (1);
00803        if (protocol_func)
00804               (*protocol_func)(fd);
00805 
00806        return (dossl(fd, argn, argc, argv));
00807 }