/*
    pmacct (Promiscuous mode IP Accounting package)
    pmacct is Copyright (C) 2003-2005 by Paolo Lucente
*/

/*
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/

#define __PMACCT_PLAYER_C

/* includes */
#include "pmacct.h"
#include "pmacct-data.h"
#include "sql_common.h"
#include "pgsql_plugin.h"
#include "util.h"

/* getopt() option string for main(); a ':' after a letter marks an option
   that takes an argument. */
#define ARGS "df:o:n:sthiP:T:U:D:H:"

/* Logfile header as read (and byte-order converted) by main(); also consulted
   by the query-building functions below. */
struct logfile_header lh;
/* SQL password: taken from -P, otherwise a built-in default set in main(). */
char sql_pwd[SRVBUFLEN];
/* Elements read from the logfile (re) and elements counted as written (we). */
int re = 0, we = 0;
/* Set by -d: enables extra diagnostic output. */
int debug = 0;
/* Set by -i: PG_cache_dbop() skips the UPDATE attempt and only INSERTs. */
int sql_dont_try_update = 0;
/* Effective SQL table/user/database: command-line override or logfile header value. */
char *sql_table, *sql_user, *sql_db;
/* Scratch buffer used by print_data() to format timestamps. */
char timebuf[SRVBUFLEN];

/*
 * Prints the command-line help to stdout.
 *
 * prog: program name shown in the "Usage:" line (normally argv[0]).
 */
void usage(char *prog)
{
  static const char *const option_help[] = {
    "  -d\tenable debug\n",
    "  -f\t[filename]\n\tplay specified file\n",
    "  -o\t[element]\n\tplay file starting at specified offset element\n",
    "  -n\t[num]\n\tnumbers of elements to play\n",
    "  -t\ttest only; don't actually write to the DB\n",
    "  -P\t[password]\n\tconnect to SQL server using the specified password\n",
    "  -U\t[user]\n\tuse the specified user when connecting to SQL server\n",
    "  -H\t[host]\n\tconnect to SQL server listening at specified hostname\n",
    "  -D\t[DB]\n\tuse the specified SQL database\n",
    "  -T\t[table]\n\tuse the specified SQL table\n",
    "  -i\tdon't try update, use insert only.\n",
  };
  size_t idx;

  printf("%s\nUsage: %s -f [filename]\n\nAvailable options:\n", PMPGPLAY_USAGE_HEADER, prog);

  /* The help lines contain no conversion specifiers, so they are written verbatim. */
  for (idx = 0; idx < sizeof(option_help) / sizeof(option_help[0]); idx++)
    fputs(option_help[idx], stdout);

  printf("\nFor suggestions, critics, bugs, contact me: %s.\n", MANTAINER);
}

/*
 * Prints the column titles for the rows emitted by print_data().
 * The MAC/VLAN columns and the width of the IP columns depend on the
 * HAVE_L2 and ENABLE_IPV6 build options, mirroring print_data().
 */
void print_header()
{
  fputs("NUM       " "ID     ", stdout);
#if defined (HAVE_L2)
  fputs("SRC MAC            " "DST MAC            " "VLAN   ", stdout);
#endif
#if defined ENABLE_IPV6
  fputs("SRC IP                                         "
        "DST IP                                         ", stdout);
#else
  fputs("SRC IP           " "DST IP           ", stdout);
#endif
  fputs("SRC PORT  " "DST PORT  " "PROTOCOL    " "TOS    ", stdout);
  fputs("PACKETS     " "FLOWS       " "BYTES       " "BASETIME\n", stdout);
}

/*
 * Prints one decoded logfile element as a row aligned with print_header().
 *
 * data: element decoded from the logfile.
 * wtc:  "what to count" bitmap from the logfile header. When COUNT_SRC_AS or
 *       COUNT_SUM_AS is set, the source address field carries an AS number
 *       and is printed as an integer; likewise for COUNT_DST_AS and the
 *       destination field.
 * num:  ordinal printed in the first column.
 */
void print_data(struct db_cache *data, u_int32_t wtc, int num)
{
  struct tm *lt;
  /* A MAC address printed as six colon-separated hex octets needs up to 17
     characters plus the terminating NUL (assuming etheraddr_string() emits
     that usual form, as the %-17s column suggests). */
  char src_mac[18], dst_mac[18];
  char src_host[INET6_ADDRSTRLEN], dst_host[INET6_ADDRSTRLEN];

  printf("%-8d  ", num);
  printf("%-5d  ", data->id);
#if defined (HAVE_L2)
  etheraddr_string(data->eth_shost, src_mac);
  printf("%-17s  ", src_mac);
  etheraddr_string(data->eth_dhost, dst_mac);
  printf("%-17s  ", dst_mac);
  printf("%-5d  ", data->vlan_id);
#endif
#if defined ENABLE_IPV6
  if (wtc & (COUNT_SRC_AS|COUNT_SUM_AS)) printf("%-45u  ", ntohl(data->src_ip.address.ipv4.s_addr));
  else {
    addr_to_str(src_host, &data->src_ip);
    printf("%-45s  ", src_host);
  }
  if (wtc & COUNT_DST_AS) printf("%-45u  ", ntohl(data->dst_ip.address.ipv4.s_addr));
  else {
    addr_to_str(dst_host, &data->dst_ip);
    printf("%-45s  ", dst_host);
  }
#else
  if (wtc & (COUNT_SRC_AS|COUNT_SUM_AS)) printf("%-15u  ", ntohl(data->src_ip.address.ipv4.s_addr));
  else {
    addr_to_str(src_host, &data->src_ip);
    printf("%-15s  ", src_host);
  }
  if (wtc & COUNT_DST_AS) printf("%-15u  ", ntohl(data->dst_ip.address.ipv4.s_addr));
  else {
    addr_to_str(dst_host, &data->dst_ip);
    printf("%-15s  ", dst_host);
  }
#endif
  printf("%-5d     ", data->src_port);
  printf("%-5d     ", data->dst_port);
  printf("%-10s  ", _protocols[data->proto].name);
  printf("%-3d    ", data->tos);
  printf("%-10u  ", data->packet_counter);
  printf("%-10u  ", data->flows_counter);
  printf("%-10u  ", data->bytes_counter);
  if (lh.sql_history) {
    lt = localtime(&data->basetime);
    /* basetime comes from the logfile; localtime() returns NULL for values it
       cannot represent, so fall back to the raw number instead of crashing. */
    if (lt && strftime(timebuf, SRVBUFLEN, "%Y-%m-%d %H:%M:%S", lt))
      printf("%s\n", timebuf);
    else printf("%ld\n", (long) data->basetime);
  }
  else printf("0\n");
}

/* Copies a command-line option argument into a freshly allocated, zeroed
   SRVBUFLEN-byte buffer (truncating longer arguments); exits instead of
   dereferencing NULL when the allocation fails. */
static char *pgplay_dup_optarg(const char *arg)
{
  char *buf = malloc(SRVBUFLEN);

  if (!buf) {
    printf("ERROR: out of memory. Exiting.\n");
    exit(1);
  }
  memset(buf, 0, SRVBUFLEN);
  strlcpy(buf, arg, SRVBUFLEN);

  return buf;
}

/*
 * Entry point: replays ("plays") a pmacct logfile into a PostgreSQL table.
 *
 * Reads the logfile header (target db/table/user/host and the "what to count"
 * bitmap, overridable with -D/-T/-U/-H) and the template header describing
 * the on-disk element layout. It then decodes each element and either writes
 * it with PG_cache_dbop() inside a single "BEGIN; LOCK ..." transaction that
 * is committed at the end, or, with -t, only reads it (printing it when -d is
 * also given).
 *
 * Exits with status 1 on invalid options, an unreadable or malformed logfile,
 * or a failed connection/lock.
 */
int main(int argc, char **argv)
{
  struct insert_data idata;
  PGresult *ret;
  FILE *f;
  unsigned char fbuf[SRVBUFLEN];
  char logfile[SRVBUFLEN];
  char default_pwd[] = "arealsmartpwd";
  int have_pwd = 0, have_logfile = 0;
  int result = 0, position = 0, howmany = 0;
  int do_nothing = 0;
  char *cl_sql_host = NULL, *cl_sql_user = NULL, *cl_sql_db = NULL, *cl_sql_table = NULL;
  char *sql_host;

  struct template_entry *teptr;
  int tot_size = 0, cnt = 0;
  u_char *te;

  struct template_header th;
  struct db_cache data;

  /* getopt() stuff */
  extern char *optarg;
  int cp;

  /* signal handling */
  signal(SIGINT, PG_exit_gracefully);

  memset(&idata, 0, sizeof(idata));
  memset(sql_data, 0, sizeof(sql_data));
  memset(update_clause, 0, sizeof(update_clause));
  memset(insert_clause, 0, sizeof(insert_clause));
  memset(lock_clause, 0, sizeof(lock_clause));
  memset(where, 0, sizeof(where));
  memset(values, 0, sizeof(values));
  memset(&data, 0, sizeof(data));
  memset(timebuf, 0, sizeof(timebuf));

  pp_size = sizeof(struct db_cache);

  while ((cp = getopt(argc, argv, ARGS)) != -1) {
    switch (cp) {
    case 'd':
      debug = TRUE;
      break;
    case 'f':
      strlcpy(logfile, optarg, sizeof(logfile));
      have_logfile = TRUE;
      break;
    case 'o':
      position = atoi(optarg);
      /* a negative offset would seek backwards into the headers */
      if (position <= 0) {
        printf("ERROR: invalid offset. Exiting.\n");
        exit(1);
      }
      break;
    case 'n':
      howmany = atoi(optarg);
      /* a negative count would silently mean "play everything" */
      if (howmany <= 0) {
        printf("ERROR: invalid number of elements. Exiting.\n");
        exit(1);
      }
      break;
    case 't':
      do_nothing = TRUE;
      break;
    case 'i':
      sql_dont_try_update = TRUE;
      break;
    case 'P':
      strlcpy(sql_pwd, optarg, sizeof(sql_pwd));
      have_pwd = TRUE;
      break;
    case 'U':
      free(cl_sql_user);
      cl_sql_user = pgplay_dup_optarg(optarg);
      break;
    case 'D':
      free(cl_sql_db);
      cl_sql_db = pgplay_dup_optarg(optarg);
      break;
    case 'H':
      free(cl_sql_host);
      cl_sql_host = pgplay_dup_optarg(optarg);
      break;
    case 'T':
      free(cl_sql_table);
      cl_sql_table = pgplay_dup_optarg(optarg);
      break;
    case 'h':
      usage(argv[0]);
      exit(0);
    default:
      usage(argv[0]);
      exit(1);
    }
  }

  /* searching for user supplied values */
  if (!howmany) howmany = -1; /* -1: no limit on the number of elements */
  if (!have_pwd) memcpy(sql_pwd, default_pwd, sizeof(default_pwd));
  if (!have_logfile) {
    usage(argv[0]);
    printf("\nERROR: missing logfile (-f)\nExiting...\n");
    exit(1);
  }

  f = fopen(logfile, "rb");
  if (!f) {
    printf("ERROR: %s does not exists\nExiting...\n", logfile);
    exit(1);
  }

  if (fread(&lh, sizeof(lh), 1, f) != 1) {
    printf("ERROR: unable to read logfile header. Exiting.\n");
    exit(1);
  }
  lh.sql_table_version = ntohs(lh.sql_table_version);
  lh.sql_optimize_clauses = ntohs(lh.sql_optimize_clauses);
  lh.sql_history = ntohs(lh.sql_history);
  lh.what_to_count = ntohl(lh.what_to_count);
  lh.magic = ntohl(lh.magic);

  if (lh.magic != MAGIC) {
    printf("ERROR: Invalid magic number. Exiting.\n");
    exit(1);
  }

  /* These strings come straight from disk (the header is filled by fread, so
     they are inline char arrays): force termination before printing or using
     them. */
  lh.sql_db[sizeof(lh.sql_db)-1] = '\0';
  lh.sql_table[sizeof(lh.sql_table)-1] = '\0';
  lh.sql_user[sizeof(lh.sql_user)-1] = '\0';
  lh.sql_host[sizeof(lh.sql_host)-1] = '\0';

  if (debug) printf("OK: Valid logfile header read.\n");
  printf("sql_db: %s\n", lh.sql_db);
  printf("sql_table: %s\n", lh.sql_table);
  printf("sql_user: %s\n", lh.sql_user);
  printf("sql_host: %s\n", lh.sql_host);
  if (cl_sql_db||cl_sql_table||cl_sql_user||cl_sql_host)
    printf("OK: Overrided by commandline options:\n");
  if (cl_sql_db) printf("sql_db: %s\n", cl_sql_db);
  if (cl_sql_table) printf("sql_table: %s\n", cl_sql_table);
  if (cl_sql_user) printf("sql_user: %s\n", cl_sql_user);
  if (cl_sql_host) printf("sql_host: %s\n", cl_sql_host);

  /* binding SQL stuff */
  sql_db = cl_sql_db ? cl_sql_db : lh.sql_db;
  sql_table = cl_sql_table ? cl_sql_table : lh.sql_table;
  sql_user = cl_sql_user ? cl_sql_user : lh.sql_user;
  sql_host = cl_sql_host ? cl_sql_host : lh.sql_host;

  if (fread(&th, sizeof(th), 1, f) != 1) {
    printf("ERROR: no template header found.\n");
    exit(1);
  }
  th.magic = ntohl(th.magic);
  th.num = ntohs(th.num);
  th.sz = ntohs(th.sz);

  if (th.magic != TH_MAGIC) {
    printf("ERROR: no template header found.\n");
    exit(1);
  }
  if (debug) printf("OK: Valid template header read.\n");
  if (th.num > N_PRIMITIVES) {
    printf("ERROR: maximum number of primitives exceeded. Exiting.\n");
    exit(1);
  }
  /* A template with no entries or zero-sized elements cannot describe any
     element; a zero element size would also make every read return 0 items. */
  if (!th.num || !th.sz) {
    printf("ERROR: malformed template header. Exiting.\n");
    exit(1);
  }
  te = calloc(th.num, sizeof(struct template_entry));
  if (!te) {
    printf("ERROR: out of memory. Exiting.\n");
    exit(1);
  }
  if (fread(te, sizeof(struct template_entry), th.num, f) != th.num) {
    printf("ERROR: malformed template header. Exiting.\n");
    exit(1);
  }

  /* checking template */
  if (th.sz >= sizeof(fbuf)) {
    printf("ERROR: Objects are too big. Exiting.\n");
    exit(1);
  }
  teptr = (struct template_entry *) te;
  for (tot_size = 0, cnt = 0; cnt < th.num; cnt++, teptr++)
    tot_size += teptr->size;
  if (tot_size != th.sz) {
    printf("ERROR: malformed template header. Size mismatch. Exiting.\n");
    exit(1);
  }
  TPL_check_sizes(&th, &data, te);

  if (!do_nothing) {
    PG_compose_conn_string(&p, sql_host);
    if (!PG_DB_Connect2(&p)) {
      printf("ALERT: PG_DB_Connect2(): PGSQL daemon failed.\n");
      exit(1);
    }
  }
  else if (debug) print_header();

  /* composing the proper (filled with primitives used during
     the current execution) SQL strings */
  idata.num_primitives = PG_compose_static_queries();
  idata.now = time(NULL);

  /* handling offset: skip 'position' elements past the headers */
  if (position && fseek(f, (long) th.sz * position, SEEK_CUR) != 0) {
    printf("ERROR: unable to seek to offset %d. Exiting.\n", position);
    exit(1);
  }

  /* open the transaction and lock the table; on failure every later
     statement in the transaction would fail too */
  if (!do_nothing) {
    ret = PQexec(p.desc, lock_clause);
    if (PQresultStatus(ret) != PGRES_COMMAND_OK) {
      printf("ALERT: unable to lock table %s: %s\n", sql_table, PQresultErrorMessage(ret));
      PQclear(ret);
      PQfinish(p.desc);
      exit(1);
    }
    PQclear(ret);
  }

  /* handling single or iterative request; howmany == -1 means no limit */
  while (howmany != 0) {
    if (howmany > 0) howmany--;

    memset(fbuf, 0, th.sz);
    /* stop on EOF, a truncated trailing element or a read error */
    if (fread(fbuf, th.sz, 1, f) != 1) break;

    re++;
    TPL_pop(fbuf, &data, &th, te);

    if (!do_nothing) result = PG_cache_dbop(p.desc, &data, &idata);
    else if (debug) print_data(&data, lh.what_to_count, (position+re));

    /* report only the element that actually failed */
    if (result) printf("WARN: unable to write element %d.\n", re);
    else we++;
  }
  if (ferror(f)) printf("WARN: error while reading %s.\n", logfile);

  if (!do_nothing) {
    ret = PQexec(p.desc, "COMMIT");
    /* If any statement failed, the server has aborted the transaction and
       answers COMMIT with a ROLLBACK tag (still PGRES_COMMAND_OK): nothing
       was written in that case either. */
    if (PQresultStatus(ret) != PGRES_COMMAND_OK || strcmp(PQcmdStatus(ret), "COMMIT") != 0)
      we = 0;
    PQclear(ret);
    printf("\nOK: written [%d/%d] elements.\n", we, re);
    PQfinish(p.desc);
  }
  else printf("OK: read [%d] elements.\n", re);
  fclose(f);
  free(te);

  return 0;
}

/*
 * Writes one decoded element to the database.
 *
 * The per-primitive handlers registered in where[] are given pointers into
 * where_clause/values_clause and fill them from cache_elem. Unless -i was
 * given, an UPDATE adding the element's counters is tried first; when it is
 * skipped or affects no rows, an INSERT is issued instead.
 *
 * Returns FALSE on success, TRUE if a query failed (the failing query and
 * libpq's error message are printed).
 */
int PG_cache_dbop(PGconn *db_desc, struct db_cache *cache_elem, struct insert_data *idata)
{
  PGresult *ret;
  char *ptr_values, *ptr_where;
  int num, have_flows = FALSE, affected_rows = 0;

  if (lh.what_to_count & COUNT_FLOWS) have_flows = TRUE;

  /* constructing SQL query */
  ptr_where = where_clause;
  ptr_values = values_clause;
  for (num = 0; num < idata->num_primitives; num++)
    (*where[num].handler)(cache_elem, idata, num, &ptr_values, &ptr_where);

  /* The trailing time(NULL) has no matching conversion in update_clause as
     built by PG_compose_static_queries(); snprintf() ignores it. */
  if (have_flows) snprintf(sql_data, sizeof(sql_data), update_clause, cache_elem->packet_counter, cache_elem->bytes_counter, cache_elem->flows_counter, time(NULL));
  else snprintf(sql_data, sizeof(sql_data), update_clause, cache_elem->packet_counter, cache_elem->bytes_counter, time(NULL));
  strncat(sql_data, where_clause, SPACELEFT(sql_data));

  /* sending UPDATE query */
  if (!sql_dont_try_update) {
    ret = PQexec(db_desc, sql_data);
    if (PQresultStatus(ret) != PGRES_COMMAND_OK) {
      /* the error string is owned by ret: print it before PQclear() frees it */
      printf("FAILED query follows:\n%s\n", sql_data);
      printf("%s\n", PQresultErrorMessage(ret));
      PQclear(ret);
      return TRUE;
    }
    /* read the affected-row count while ret is still valid */
    affected_rows = PG_affected_rows(ret);
    PQclear(ret);
  }

  if (sql_dont_try_update || !affected_rows) {
    /* UPDATE skipped or matched no row: trying with an INSERT query */
    snprintf(sql_data, sizeof(sql_data), "%s", insert_clause);
    if (have_flows) snprintf(ptr_values, SPACELEFT(values_clause), ", %u, %lu, %u)", cache_elem->packet_counter, cache_elem->bytes_counter, cache_elem->flows_counter);
    else snprintf(ptr_values, SPACELEFT(values_clause), ", %u, %lu)", cache_elem->packet_counter, cache_elem->bytes_counter);
    strncat(sql_data, values_clause, SPACELEFT(sql_data));

    ret = PQexec(db_desc, sql_data);
    if (PQresultStatus(ret) != PGRES_COMMAND_OK) {
      printf("FAILED query follows:\n%s\n", sql_data);
      printf("%s\n", PQresultErrorMessage(ret));
      PQclear(ret);
      return TRUE;
    }
    PQclear(ret);
  }

  if (debug) printf("%s\n\n", sql_data);

  return FALSE;
}

/*
 * Registers the timestamp primitive when the logfile header has sql_history
 * set. It adds the stamp_updated/stamp_inserted columns to insert_clause,
 * the ABSTIME(...) templates to values[primitive] and the stamp_inserted
 * condition to where[primitive], and binds count_timestamp_handler.
 *
 * primitive: index of the next free slot in values[]/where[].
 * Returns the next free slot (unchanged when history is disabled).
 */
int PG_evaluate_history(int primitive)
{
  if (lh.sql_history) {
    if (primitive) {
      strncat(insert_clause, ", ", SPACELEFT(insert_clause));
      /* bound by the space left in the buffer, not by its total size */
      strncat(values[primitive].string, ", ", SPACELEFT(values[primitive].string));
      strncat(where[primitive].string, " AND ", SPACELEFT(where[primitive].string));
    }
    strncat(where[primitive].string, "ABSTIME(%u)::Timestamp::Timestamp without time zone = ", SPACELEFT(where[primitive].string));
    strncat(where[primitive].string, "stamp_inserted", SPACELEFT(where[primitive].string));

    strncat(insert_clause, "stamp_updated, stamp_inserted", SPACELEFT(insert_clause));
    strncat(values[primitive].string, "ABSTIME(%u)::Timestamp, ABSTIME(%u)::Timestamp", SPACELEFT(values[primitive].string));

    where[primitive].type = values[primitive].type = TIMESTAMP;
    values[primitive].handler = where[primitive].handler = count_timestamp_handler;
    primitive++;
  }

  return primitive;
}

/* Appends one primitive's text to the static query fragments: its column to
   insert_clause, its value template to values[primitive] and its condition
   to where[primitive], each preceded by the proper separator unless it is the
   first primitive. Every append is bounded by the space left in its buffer. */
static void PG_player_append_primitive(int primitive, const char *column,
                                       const char *value_fmt, const char *where_fmt)
{
  if (primitive) {
    strncat(insert_clause, ", ", SPACELEFT(insert_clause));
    strncat(values[primitive].string, ", ", SPACELEFT(values[primitive].string));
    strncat(where[primitive].string, " AND ", SPACELEFT(where[primitive].string));
  }
  strncat(insert_clause, column, SPACELEFT(insert_clause));
  strncat(values[primitive].string, value_fmt, SPACELEFT(values[primitive].string));
  strncat(where[primitive].string, where_fmt, SPACELEFT(where[primitive].string));
}

/*
 * Registers the accounting primitives (columns) used by the replayed
 * logfile. Each primitive gets its column, VALUES template and WHERE
 * condition, plus a type and handler in values[]/where[].
 *
 * With sql_optimize_clauses set, exactly the logfile's "what to count"
 * bitmap is used. Otherwise a full column set is built, and columns the
 * logfile did not count are filled by the fake_* handlers.
 *
 * primitive: index of the next free slot in values[]/where[].
 * Returns the next free slot. Exits if a primitive that was counted needs a
 * newer SQL table version than the header declares.
 */
int PG_evaluate_primitives(int primitive)
{
  u_int32_t what_to_count = 0, fakes = 0;
  short int assume_custom_table = FALSE;

  if (lh.sql_optimize_clauses) {
    what_to_count = lh.what_to_count;
    assume_custom_table = TRUE;
  }
  else {
    /* we are requested to avoid optimization; then we'll construct an
       all-true "what to count" bitmap. */

    if (lh.what_to_count & COUNT_SRC_MAC) what_to_count |= COUNT_SRC_MAC;
    else fakes |= FAKE_SRC_MAC;
    if (lh.what_to_count & COUNT_DST_MAC) what_to_count |= COUNT_DST_MAC;
    else fakes |= FAKE_DST_MAC;

    if (lh.what_to_count & (COUNT_SRC_HOST|COUNT_SRC_NET)) what_to_count |= COUNT_SRC_HOST;
    else if (lh.what_to_count & COUNT_SRC_AS) what_to_count |= COUNT_SRC_AS;
    else if (lh.what_to_count & COUNT_SUM_HOST) what_to_count |= COUNT_SUM_HOST;
    else if (lh.what_to_count & COUNT_SUM_NET) what_to_count |= COUNT_SUM_NET;
    else if (lh.what_to_count & COUNT_SUM_AS) what_to_count |= COUNT_SUM_AS;
    else {
      if (lh.what_to_count & COUNT_DST_AS) what_to_count |= COUNT_SRC_AS;
      else fakes |= FAKE_SRC_HOST;
    }

    if (lh.what_to_count & (COUNT_DST_HOST|COUNT_DST_NET)) what_to_count |= COUNT_DST_HOST;
    else if (lh.what_to_count & COUNT_DST_AS) what_to_count |= COUNT_DST_AS;
    else {
      if (lh.what_to_count & (COUNT_SRC_AS|COUNT_SUM_AS)) what_to_count |= COUNT_DST_AS;
      else fakes |= FAKE_DST_HOST;
    }

    if (lh.what_to_count & COUNT_SUM_PORT) what_to_count |= COUNT_SUM_PORT;
    if (lh.what_to_count & COUNT_SUM_MAC) what_to_count |= COUNT_SUM_MAC;

    what_to_count |= COUNT_SRC_PORT|COUNT_DST_PORT|COUNT_IP_PROTO|COUNT_ID|COUNT_VLAN|COUNT_IP_TOS;
  }

  /* 1st part: arranging pointers to an opaque structure and
     composing the static selection (WHERE) string */

#if defined (HAVE_L2)
  if (what_to_count & (COUNT_SRC_MAC|COUNT_SUM_MAC)) {
    PG_player_append_primitive(primitive, "mac_src", "\'%s\'", "mac_src=\'%s\'");
    values[primitive].type = where[primitive].type = COUNT_SRC_MAC;
    values[primitive].handler = where[primitive].handler = count_src_mac_handler;
    primitive++;
  }

  if (what_to_count & COUNT_DST_MAC) {
    PG_player_append_primitive(primitive, "mac_dst", "\'%s\'", "mac_dst=\'%s\'");
    values[primitive].type = where[primitive].type = COUNT_DST_MAC;
    values[primitive].handler = where[primitive].handler = count_dst_mac_handler;
    primitive++;
  }

  if (what_to_count & COUNT_VLAN) {
    if ((lh.sql_table_version < 2) && !assume_custom_table) {
      if (lh.what_to_count & COUNT_VLAN) {
        printf("ERROR: The use of VLAN accounting requires SQL table v2. Exiting.\n");
        exit(1);
      }
      else what_to_count ^= COUNT_VLAN;
    }
    else {
      PG_player_append_primitive(primitive, "vlan", "%u", "vlan=%u");
      values[primitive].type = where[primitive].type = COUNT_VLAN;
      values[primitive].handler = where[primitive].handler = count_vlan_handler;
      primitive++;
    }
  }
#endif

  if (what_to_count & (COUNT_SRC_HOST|COUNT_SRC_NET|COUNT_SUM_HOST|COUNT_SUM_NET)) {
    PG_player_append_primitive(primitive, "ip_src", "\'%s\'", "ip_src=\'%s\'");
    values[primitive].type = where[primitive].type = COUNT_SRC_HOST;
    values[primitive].handler = where[primitive].handler = count_src_host_handler;
    primitive++;
  }

  if (what_to_count & (COUNT_DST_HOST|COUNT_DST_NET)) {
    PG_player_append_primitive(primitive, "ip_dst", "\'%s\'", "ip_dst=\'%s\'");
    values[primitive].type = where[primitive].type = COUNT_DST_HOST;
    values[primitive].handler = where[primitive].handler = count_dst_host_handler;
    primitive++;
  }

  /* AS numbers are stored in the ip_src/ip_dst columns */
  if (what_to_count & (COUNT_SRC_AS|COUNT_SUM_AS)) {
    PG_player_append_primitive(primitive, "ip_src", "\'%u\'", "ip_src=\'%u\'");
    values[primitive].type = where[primitive].type = COUNT_SRC_AS;
    values[primitive].handler = where[primitive].handler = count_src_as_handler;
    primitive++;
  }

  if (what_to_count & COUNT_DST_AS) {
    PG_player_append_primitive(primitive, "ip_dst", "\'%u\'", "ip_dst=\'%u\'");
    values[primitive].type = where[primitive].type = COUNT_DST_AS;
    values[primitive].handler = where[primitive].handler = count_dst_as_handler;
    primitive++;
  }

  if (what_to_count & (COUNT_SRC_PORT|COUNT_SUM_PORT)) {
    PG_player_append_primitive(primitive, "port_src", "%u", "port_src=%u");
    values[primitive].type = where[primitive].type = COUNT_SRC_PORT;
    values[primitive].handler = where[primitive].handler = count_src_port_handler;
    primitive++;
  }

  if (what_to_count & COUNT_DST_PORT) {
    PG_player_append_primitive(primitive, "port_dst", "%u", "port_dst=%u");
    values[primitive].type = where[primitive].type = COUNT_DST_PORT;
    values[primitive].handler = where[primitive].handler = count_dst_port_handler;
    primitive++;
  }

  if (what_to_count & COUNT_IP_TOS) {
    if ((lh.sql_table_version < 3) && !assume_custom_table) {
      if (lh.what_to_count & COUNT_IP_TOS) {
        printf("ERROR: The use of ToS/DSCP accounting requires SQL table v3. Exiting.\n");
        exit(1);
      }
      else what_to_count ^= COUNT_IP_TOS;
    }
    else {
      PG_player_append_primitive(primitive, "tos", "%u", "tos=%u");
      values[primitive].type = where[primitive].type = COUNT_IP_TOS;
      values[primitive].handler = where[primitive].handler = count_ip_tos_handler;
      primitive++;
    }
  }

  if (what_to_count & COUNT_IP_PROTO) {
    PG_player_append_primitive(primitive, "ip_proto", "%u", "ip_proto=%u");
    values[primitive].type = where[primitive].type = COUNT_IP_PROTO;
    values[primitive].handler = where[primitive].handler = PG_count_ip_proto_handler;
    primitive++;
  }

  if (what_to_count & COUNT_ID) {
    if ((lh.sql_table_version < 2) && !assume_custom_table) {
      if (lh.what_to_count & COUNT_ID) {
        printf("ERROR: The use of IDs requires SQL table version 2. Exiting.\n");
        exit(1);
      }
      else what_to_count ^= COUNT_ID;
    }
    else {
      PG_player_append_primitive(primitive, "agent_id", "%u", "agent_id=%u");
      values[primitive].type = where[primitive].type = COUNT_ID;
      values[primitive].handler = where[primitive].handler = count_id_handler;
      primitive++;
    }
  }

  /* columns required by the full table layout but not counted in the logfile */
  if (fakes & FAKE_SRC_MAC) {
    PG_player_append_primitive(primitive, "mac_src", "\'%s\'", "mac_src=\'%s\'");
    values[primitive].type = where[primitive].type = FAKE_SRC_MAC;
    values[primitive].handler = where[primitive].handler = fake_mac_handler;
    primitive++;
  }

  if (fakes & FAKE_DST_MAC) {
    PG_player_append_primitive(primitive, "mac_dst", "\'%s\'", "mac_dst=\'%s\'");
    values[primitive].type = where[primitive].type = FAKE_DST_MAC;
    values[primitive].handler = where[primitive].handler = fake_mac_handler;
    primitive++;
  }

  if (fakes & FAKE_SRC_HOST) {
    PG_player_append_primitive(primitive, "ip_src", "\'%s\'", "ip_src=\'%s\'");
    values[primitive].type = where[primitive].type = FAKE_SRC_HOST;
    values[primitive].handler = where[primitive].handler = fake_host_handler;
    primitive++;
  }

  if (fakes & FAKE_DST_HOST) {
    PG_player_append_primitive(primitive, "ip_dst", "\'%s\'", "ip_dst=\'%s\'");
    values[primitive].type = where[primitive].type = FAKE_DST_HOST;
    values[primitive].handler = where[primitive].handler = fake_host_handler;
    primitive++;
  }

  return primitive;
}

int PG_compose_static_queries()
{
  int primitives=0, have_flows=0;

  if (lh.what_to_count & COUNT_FLOWS || (lh.sql_table_version >= 4 && !lh.sql_optimize_clauses)) {
    lh.what_to_count |= COUNT_FLOWS;
    have_flows = TRUE;

    if (lh.sql_table_version < 4 && !lh.sql_optimize_clauses) {
      printf("ERROR: The accounting of flows requires SQL table v4. Exiting.\n");
      exit(1);
    }
  }

  /* "INSERT INTO ... VALUES ... " and "... WHERE ..." stuff */
  strncpy(where[primitives].string, " WHERE ", sizeof(where[primitives].string));
  snprintf(insert_clause, sizeof(insert_clause), "INSERT INTO %s (", sql_table);
  strncpy(values[primitives].string, " VALUES (", sizeof(values[primitives].string));
  primitives = PG_evaluate_history(primitives);
  primitives = PG_evaluate_primitives(primitives);
  strncat(insert_clause, ", packets, bytes", SPACELEFT(insert_clause));
  if (have_flows) strncat(insert_clause, ", flows", SPACELEFT(insert_clause));
  strncat(insert_clause, ")", SPACELEFT(insert_clause));

  /* "LOCK ..." stuff */
  snprintf(lock_clause, sizeof(lock_clause), "BEGIN; LOCK %s IN EXCLUSIVE MODE;", sql_table);

  /* "UPDATE ... SET ..." stuff */
  snprintf(update_clause, sizeof(update_clause), "UPDATE %s ", sql_table);
  strncat(update_clause, "SET packets=packets+%u, bytes=bytes+%lu", SPACELEFT(update_clause));
  if (have_flows) strncat(update_clause, ", flows=flows+%u", SPACELEFT(update_clause));
  if (lh.sql_history) strncat(update_clause, ", stamp_updated=CURRENT_TIMESTAMP(0)", SPACELEFT(update_clause));

  return primitives;
}

/*
 * SIGINT handler installed by main(): prints the read/written element
 * counters and terminates the process with status 0.
 *
 * NOTE(review): printf()/exit() are not async-signal-safe. Also, no COMMIT
 * is issued here, so the "written" count may not reflect what the database
 * retains; confirm whether that is intended.
 */
void PG_exit_gracefully(int signum)
{
  (void) signum; /* only ever installed for SIGINT */

  fprintf(stdout, "\nOK: written [%u/%u] elements.\n", we, re);
  exit(0);
}

/*
 * Appends "key='value'" to the NUL-terminated string in buf (capacity
 * bufsize), separated by a space from any existing content. The value is
 * single-quoted, with backslashes and single quotes escaped by a backslash,
 * as libpq's keyword/value connection-string syntax requires. Values with
 * spaces or quotes (e.g. passwords) therefore survive intact. Output that
 * does not fit is truncated, and buf always stays NUL-terminated.
 */
static void PG_conninfo_append(char *buf, size_t bufsize, const char *key, const char *value)
{
  size_t len = strlen(buf);
  int n;

  if (len + 1 >= bufsize) return;
  n = snprintf(buf + len, bufsize - len, "%s%s='", len ? " " : "", key);
  if (n < 0 || (size_t) n >= bufsize - len) return; /* truncated by snprintf */
  len += (size_t) n;

  for (; *value; value++) {
    size_t need = (*value == '\'' || *value == '\\') ? 2 : 1;

    /* keep room for the closing quote and the terminator */
    if (len + need + 2 > bufsize) break;
    if (need == 2) buf[len++] = '\\';
    buf[len++] = *value;
  }
  if (len + 2 <= bufsize) buf[len++] = '\'';
  buf[len] = '\0';
}

/*
 * Builds db->conn_string (once; an existing string is left untouched) from
 * the global sql_db, sql_user and sql_pwd, plus host when it is non-NULL.
 * Exits if the SRVBUFLEN-byte buffer cannot be allocated.
 */
void PG_compose_conn_string(struct DBdesc *db, char *host)
{
  if (!db->conn_string) {
    db->conn_string = malloc(SRVBUFLEN);
    if (!db->conn_string) {
      printf("ERROR: PG_compose_conn_string(): out of memory. Exiting.\n");
      exit(1);
    }
    db->conn_string[0] = '\0';

    PG_conninfo_append(db->conn_string, SRVBUFLEN, "dbname", sql_db);
    PG_conninfo_append(db->conn_string, SRVBUFLEN, "user", sql_user);
    PG_conninfo_append(db->conn_string, SRVBUFLEN, "password", sql_pwd);
    if (host) PG_conninfo_append(db->conn_string, SRVBUFLEN, "host", host);
  }
}

int PG_DB_Connect2(struct DBdesc *db)
{
  db->desc = PQconnectdb(db->conn_string);
  if (PQstatus(db->desc) == CONNECTION_BAD) db->connected = FALSE;
  else db->connected = TRUE;

  return db->connected;
}

/*
 * Returns the number of rows affected by the command that produced result,
 * parsed from libpq's PQcmdTuples() string (0 when that string is empty).
 */
static int PG_affected_rows(PGresult *result)
{
  const char *tuples = PQcmdTuples(result);

  return (int) strtol(tuples, NULL, 10);
}
