/*
 *	Filters: utility functions
 *
 *	(c) 1998 Pavel Machek <pavel@ucw.cz>
 *	(c) 2019 Maria Matejka <mq@jmq.cz>
 *
 *	Can be freely distributed and used under the terms of the GNU GPL.
 *
 */

#include "nest/bird.h"
#include "lib/lists.h"
#include "lib/resource.h"
#include "lib/socket.h"
#include "lib/string.h"
#include "lib/unaligned.h"
#include "lib/net.h"
#include "lib/ip.h"
#include "lib/hash.h"
#include "nest/route.h"
#include "nest/protocol.h"
#include "nest/iface.h"
#include "lib/attrs.h"
#include "conf/conf.h"
#include "filter/filter.h"
#include "filter/f-inst.h"
#include "filter/data.h"

/*
 * Human-readable names of filter value types, indexed by the type code.
 * Codes without an initializer here are left NULL; f_type_name() reports
 * those as "?".
 */
static const char * const f_type_str[] = {
  [T_VOID]	= "void",
  [T_NONE]	= "none",

  [T_OPAQUE]	= "opaque byte string",
  [T_IFACE]	= "interface",

  [T_INT]	= "int",
  [T_BOOL]	= "bool",
  [T_PAIR]	= "pair",
  [T_QUAD]	= "quad",

  [T_ENUM_RTS]	= "enum rts",
  [T_ENUM_BGP_ORIGIN] = "enum bgp_origin",
  [T_ENUM_SCOPE] = "enum scope",
  [T_ENUM_RTD]	= "enum rtd",
  [T_ENUM_ROA]	= "enum roa",
  [T_ENUM_NETTYPE] = "enum nettype",
  [T_ENUM_RA_PREFERENCE] = "enum ra_preference",
  [T_ENUM_AF]	= "enum af",

  [T_IP]	= "ip",
  [T_NET]	= "prefix",
  [T_STRING]	= "string",
  [T_BYTESTRING]	= "bytestring",
  [T_PATH_MASK]	= "bgpmask",
  [T_PATH_MASK_ITEM] = "bgpmask item",
  [T_PATH]	= "bgppath",
  [T_CLIST]	= "clist",
  [T_EC]	= "ec",
  [T_ECLIST]	= "eclist",
  [T_LC]	= "lc",
  [T_LCLIST]	= "lclist",
  [T_RD]	= "rd",

  [T_ROUTE]	= "route",
  [T_ROUTES_BLOCK] = "block of routes",

  [T_SET]	= "set",
  [T_PREFIX_SET] = "prefix set",
};

/*
 * The table must have exactly one slot for every value a btype can hold,
 * so that f_type_name() may index it with any btype without a bounds check.
 */
STATIC_ASSERT((1 << (8 * sizeof(btype))) == ARRAY_SIZE(f_type_str));

/* Return the printable name of type @t, or "?" when the name table has no entry for it. */
const char *
f_type_name(btype t)
{
  const char *name = f_type_str[t];

  if (name)
    return name;

  return "?";
}

/*
 * Map a container type to the type of its elements.
 * Types that are not one of the listed containers yield T_VOID.
 */
btype
f_type_element_type(btype t)
{
  if (t == T_PATH)
    return T_INT;

  if (t == T_CLIST)
    return T_PAIR;

  if (t == T_ECLIST)
    return T_EC;

  if (t == T_LCLIST)
    return T_LC;

  if (t == T_ROUTES_BLOCK)
    return T_ROUTE;

  return T_VOID;
}

/*
 * Shared constant trie and a ready-made T_PREFIX_SET value pointing at it.
 * NOTE(review): by name these represent the empty prefix set; the meaning of
 * .ipv4 = -1 is defined by the trie code, not here -- confirm there.
 */
const struct f_trie f_const_empty_trie = { .ipv4 = -1, };
const struct f_val f_const_empty_prefix_set = {
  .type = T_PREFIX_SET,
  .val.ti = &f_const_empty_trie,
};

/*
 * Append a textual form of path mask @p to @buf, wrapped in "[= " ... "=]".
 * A PM_LOOP item prints nothing itself; it makes the next non-loop item
 * be followed by "+ ".
 */
static void
pm_format(const struct f_path_mask *p, buffer *buf)
{
  int plus_pending = 0;

  buffer_puts(buf, "[= ");

  for (uint idx = 0; idx < p->len; idx++)
  {
    const struct f_path_mask_item *it = &p->item[idx];

    switch (it->kind)
    {
    case PM_LOOP:
      /* Remember the loop marker and move straight on to the next item */
      plus_pending = 1;
      continue;

    case PM_ASN:
      buffer_print(buf, "%u ", it->asn);
      break;

    case PM_QUESTION:
      buffer_puts(buf, "? ");
      break;

    case PM_ASTERISK:
      buffer_puts(buf, "* ");
      break;

    case PM_ASN_RANGE:
      buffer_print(buf, "%u..%u ", it->from, it->to);
      break;

    case PM_ASN_SET:
      tree_format(it->set, buf);
      buffer_puts(buf, " ");
      break;

    case PM_ASN_EXPR:
      /* Expression items are not expected in a mask being formatted */
      ASSERT(0);
      break;
    }

    if (plus_pending)
    {
      buffer_puts(buf, "+ ");
      plus_pending = 0;
    }
  }

  buffer_puts(buf, "=]");
}

/*
 * Three-way comparison of two large communities: ordered by asn first,
 * then ldp1, then ldp2. Returns -1, 0 or 1.
 */
static inline int
lcomm_cmp(lcomm v1, lcomm v2)
{
  if (v1.asn < v2.asn)
    return -1;
  if (v1.asn > v2.asn)
    return 1;

  if (v1.ldp1 < v2.ldp1)
    return -1;
  if (v1.ldp1 > v2.ldp1)
    return 1;

  if (v1.ldp2 < v2.ldp2)
    return -1;
  if (v1.ldp2 > v2.ldp2)
    return 1;

  return 0;
}

/* Ordering of two values whose type tags differ; helper for val_compare(). */
static int
val_compare_mixed(const struct f_val *v1, const struct f_val *v2)
{
  /* Hack for else: a void value sorts before any non-void one */
  if (v1->type == T_VOID)
    return -1;
  if (v2->type == T_VOID)
    return 1;

  /* IP->Quad implicit conversion */
  if ((v1->type == T_QUAD) && val_is_ip4(v2))
    return uint_cmp(v1->val.i, ipa_to_u32(v2->val.ip));
  if (val_is_ip4(v1) && (v2->type == T_QUAD))
    return uint_cmp(ipa_to_u32(v1->val.ip), v2->val.i);

  DBG( "Types do not match in val_compare\n" );
  return F_CMP_ERROR;
}

/**
 * val_compare - compare two values
 * @v1: first value
 * @v2: second value
 *
 * Three-way comparison of two filter values. The result is -1, 0 or 1 for
 * <, =, >, or F_CMP_ERROR when the values cannot be ordered (incompatible
 * types, or a type with no ordering defined here). The tree module depends
 * on the results being consistent, as it uses them to build balanced trees.
 */
int
val_compare(const struct f_val *v1, const struct f_val *v2)
{
  if (v1->type != v2->type)
    return val_compare_mixed(v1, v2);

  switch (v1->type)
  {
  case T_VOID:
    return 0;

  /* Everything stored in the plain integer slot */
  case T_ENUM:
  case T_INT:
  case T_BOOL:
  case T_PAIR:
  case T_QUAD:
    return uint_cmp(v1->val.i, v2->val.i);

  /* 64-bit values */
  case T_EC:
  case T_RD:
    return u64_cmp(v1->val.ec, v2->val.ec);

  case T_LC:
    return lcomm_cmp(v1->val.lc, v2->val.lc);

  case T_IP:
    return ipa_compare(v1->val.ip, v2->val.ip);

  case T_NET:
    return net_compare(v1->val.net, v2->val.net);

  case T_STRING:
    return strcmp(v1->val.s, v2->val.s);

  case T_PATH:
    return as_path_compare(v1->val.ad, v2->val.ad);

  /* No ordering is defined for routes or for any remaining type */
  case T_ROUTE:
  case T_ROUTES_BLOCK:
  default:
    return F_CMP_ERROR;
  }
}

/* Two byte strings are the same when they have equal length and equal contents. */
static inline int
bs_same(const struct adata *bs1, const struct adata *bs2)
{
  if (bs1->length != bs2->length)
    return 0;

  return memcmp(bs1->data, bs2->data, bs1->length) == 0;
}

/*
 * Compare two path mask items. Items of different kinds never match;
 * kinds that carry no payload handled below match on kind alone.
 */
static inline int
pmi_same(const struct f_path_mask_item *mi1, const struct f_path_mask_item *mi2)
{
  if (mi1->kind != mi2->kind)
    return 0;

  switch (mi1->kind) {
    case PM_ASN:
      return mi1->asn == mi2->asn;

    case PM_ASN_EXPR:
      return !!f_same(mi1->expr, mi2->expr);

    case PM_ASN_RANGE:
      return (mi1->from == mi2->from) && (mi1->to == mi2->to);

    case PM_ASN_SET:
      return !!same_tree(mi1->set, mi2->set);
  }

  /* Remaining kinds have nothing further to compare */
  return 1;
}

/* Two path masks are the same when they have equal length and pairwise equal items. */
static int
pm_same(const struct f_path_mask *m1, const struct f_path_mask *m2)
{
  if (m1->len != m2->len)
    return 0;

  uint pos = 0;

  /* Advance while items agree; stopping early means a mismatch was found */
  while ((pos < m1->len) && pmi_same(&m1->item[pos], &m2->item[pos]))
    pos++;

  return pos == m1->len;
}

/**
 * val_same - compare two values
 * @v1: first value
 * @v2: second value
 *
 * Returns 1 if the values are the same and 0 if not. Whatever val_compare()
 * can decide is decided by it; only for values it reports as incomparable
 * does the type-specific equality below apply. Values of different types
 * that val_compare() cannot compare are simply not the same.
 */
int
val_same(const struct f_val *v1, const struct f_val *v2)
{
  int cmp = val_compare(v1, v2);

  if (cmp != F_CMP_ERROR)
    return cmp == 0;

  if (v1->type != v2->type)
    return 0;

  switch (v1->type)
  {
  case T_BYTESTRING:
    return bs_same(v1->val.bs, v2->val.bs);

  case T_PATH_MASK:
    return pm_same(v1->val.path_mask, v2->val.path_mask);

  case T_PATH_MASK_ITEM:
    return pmi_same(&(v1->val.pmi), &(v2->val.pmi));

  case T_PATH:
  case T_CLIST:
  case T_ECLIST:
  case T_LCLIST:
    return adata_same(v1->val.ad, v2->val.ad);

  case T_SET:
    return same_tree(v1->val.t, v2->val.t);

  case T_PREFIX_SET:
    return trie_same(v1->val.ti, v2->val.ti);

  case T_ROUTE:
    return rte_same(v1->val.rte, v2->val.rte);

  case T_ROUTES_BLOCK:
    {
      /* Blocks are the same when equally long and equal route by route */
      if (v1->val.rte_block.len != v2->val.rte_block.len)
        return 0;

      uint idx = 0;
      while ((idx < v1->val.rte_block.len) &&
             rte_same(v1->val.rte_block.rte[idx], v2->val.rte_block.rte[idx]))
        idx++;

      return idx == v1->val.rte_block.len;
    }

  default:
    bug("Invalid type in val_same(): %x", v1->type);
  }
}

/*
 * Decide which value type should be used to look clist members up in @set
 * and store it in @v->type. Returns 1 when the set is usable for clist
 * matching (a NULL set counts as usable, with type T_VOID), 0 otherwise.
 */
int
clist_set_type(const struct f_tree *set, struct f_val *v)
{
  int usable = 1;

  if (!set)
    v->type = T_VOID;
  else if (set->from.type == T_PAIR)
    v->type = T_PAIR;
  else if (set->from.type == T_QUAD)
    v->type = T_QUAD;
  else if ((set->from.type == T_IP) &&
           val_is_ip4(&(set->from)) && val_is_ip4(&(set->to)))
    v->type = T_QUAD;		/* IPv4 bounds are matched as quads */
  else
  {
    v->type = T_VOID;
    usable = 0;
  }

  return usable;
}

int
clist_match_set(const struct adata *clist, const struct f_tree *set)
{
  if (!clist)
    return 0;

  struct f_val v;
  if (!clist_set_type(set, &v))
    return F_CMP_ERROR;

  u32 *l = (u32 *) clist->data;
  u32 *end = l + clist->length/4;

  while (l < end) {
    v.val.i = *l++;
    if (find_tree(set, &v))
      return 1;
  }
  return 0;
}

int
eclist_match_set(const struct adata *list, const struct f_tree *set)
{
  if (!list)
    return 0;

  if (!eclist_set_type(set))
    return F_CMP_ERROR;

  struct f_val v;
  u32 *l = int_set_get_data(list);
  int len = int_set_get_size(list);
  int i;

  v.type = T_EC;
  for (i = 0; i < len; i += 2) {
    v.val.ec = ec_get(l, i);
    if (find_tree(set, &v))
      return 1;
  }

  return 0;
}

int
lclist_match_set(const struct adata *list, const struct f_tree *set)
{
  if (!list)
    return 0;

  if (!lclist_set_type(set))
    return F_CMP_ERROR;

  struct f_val v;
  u32 *l = int_set_get_data(list);
  int len = int_set_get_size(list);
  int i;

  v.type = T_LC;
  for (i = 0; i < len; i += 3) {
    v.val.lc = lc_get(l, i);
    if (find_tree(set, &v))
      return 1;
  }

  return 0;
}

/*
 * clist_filter - keep or drop community list members by set membership
 * @pool: linpool the result is allocated from when a new list is needed
 * @list: community list to filter (may be NULL)
 * @set: either a T_SET tree or another clist to test membership against
 * @pos: 1 to keep members found in @set, 0 to keep members not found in it
 *
 * Returns NULL for a NULL @list, @list itself when nothing was removed,
 * and a newly allocated list otherwise.
 */
const struct adata *
clist_filter(struct linpool *pool, const struct adata *list, const struct f_val *set, int pos)
{
  if (!list)
    return NULL;

  int tree = (set->type == T_SET);	/* 1 -> set is T_SET, 0 -> set is T_CLIST */
  struct f_val v;
  if (tree)
    clist_set_type(set->val.t, &v);
  else
    v.type = T_PAIR;

  int len = int_set_get_size(list);
  u32 *l = int_set_get_data(list);
  /* A variable-length array must have a positive size, so keep at least one slot for an empty list */
  u32 tmp[(len > 0) ? len : 1];
  u32 *k = tmp;
  u32 *end = l + len;

  while (l < end) {
    v.val.i = *l++;
    /* pos && member(val, set) || !pos && !member(val, set),  member() depends on tree */
    if ((tree ? !!find_tree(set->val.t, &v) : int_set_contains(set->val.ad, v.val.i)) == pos)
      *k++ = v.val.i;
  }

  /* Nothing filtered out: hand back the original list instead of copying */
  uint nl = (k - tmp) * sizeof(u32);
  if (nl == list->length)
    return list;

  struct adata *res = lp_alloc_adata(pool, nl);
  memcpy(res->data, tmp, nl);
  return res;
}

/*
 * eclist_filter - keep or drop extended communities by set membership
 * @pool: linpool the result is allocated from when a new list is needed
 * @list: extended community list to filter (may be NULL)
 * @set: either a T_SET tree or another eclist to test membership against
 * @pos: 1 to keep members found in @set, 0 to keep members not found in it
 *
 * Returns NULL for a NULL @list, @list itself when nothing was removed,
 * and a newly allocated list otherwise.
 */
const struct adata *
eclist_filter(struct linpool *pool, const struct adata *list, const struct f_val *set, int pos)
{
  if (!list)
    return NULL;

  int tree = (set->type == T_SET);	/* 1 -> set is T_SET, 0 -> set is an eclist */
  struct f_val v;

  int len = int_set_get_size(list);
  u32 *l = int_set_get_data(list);
  /* A variable-length array must have a positive size, so keep at least one slot for an empty list */
  u32 tmp[(len > 0) ? len : 1];
  u32 *k = tmp;
  int i;

  v.type = T_EC;
  /* Each extended community takes two consecutive words */
  for (i = 0; i < len; i += 2) {
    v.val.ec = ec_get(l, i);
    /* pos && member(val, set) || !pos && !member(val, set),  member() depends on tree */
    if ((tree ? !!find_tree(set->val.t, &v) : ec_set_contains(set->val.ad, v.val.ec)) == pos) {
      *k++ = l[i];
      *k++ = l[i+1];
    }
  }

  /* Nothing filtered out: hand back the original list instead of copying */
  uint nl = (k - tmp) * sizeof(u32);
  if (nl == list->length)
    return list;

  struct adata *res = lp_alloc_adata(pool, nl);
  memcpy(res->data, tmp, nl);
  return res;
}

/*
 * lclist_filter - keep or drop large communities by set membership
 * @pool: linpool the result is allocated from when a new list is needed
 * @list: large community list to filter (may be NULL)
 * @set: either a T_SET tree or another lclist to test membership against
 * @pos: 1 to keep members found in @set, 0 to keep members not found in it
 *
 * Returns NULL for a NULL @list, @list itself when nothing was removed,
 * and a newly allocated list otherwise.
 */
const struct adata *
lclist_filter(struct linpool *pool, const struct adata *list, const struct f_val *set, int pos)
{
  if (!list)
    return NULL;

  int tree = (set->type == T_SET);	/* 1 -> set is T_SET, 0 -> set is an lclist */
  struct f_val v;

  int len = int_set_get_size(list);
  u32 *l = int_set_get_data(list);
  /* A variable-length array must have a positive size, so keep at least one slot for an empty list */
  u32 tmp[(len > 0) ? len : 1];
  u32 *k = tmp;
  int i;

  v.type = T_LC;
  /* Each large community takes three consecutive words */
  for (i = 0; i < len; i += 3) {
    v.val.lc = lc_get(l, i);
    /* pos && member(val, set) || !pos && !member(val, set),  member() depends on tree */
    if ((tree ? !!find_tree(set->val.t, &v) : lc_set_contains(set->val.ad, v.val.lc)) == pos)
      k = lc_copy(k, l+i);
  }

  /* Nothing filtered out: hand back the original list instead of copying */
  uint nl = (k - tmp) * sizeof(u32);
  if (nl == list->length)
    return list;

  struct adata *res = lp_alloc_adata(pool, nl);
  memcpy(res->data, tmp, nl);
  return res;
}

/**
 * val_in_range - implement |~| operator
 * @v1: element
 * @v2: set
 *
 * Checks if @v1 is element (|~| operator) of @v2.
 */
int
val_in_range(const struct f_val *v1, const struct f_val *v2)
{
  if ((v1->type == T_PATH) && (v2->type == T_PATH_MASK))
    return as_path_match(v1->val.ad, v2->val.path_mask);

  if ((v1->type == T_INT) && (v2->type == T_PATH))
    return as_path_contains(v2->val.ad, v1->val.i, 1);

  if (((v1->type == T_PAIR) || (v1->type == T_QUAD)) && (v2->type == T_CLIST))
    return int_set_contains(v2->val.ad, v1->val.i);
  /* IP->Quad implicit conversion */
  if (val_is_ip4(v1) && (v2->type == T_CLIST))
    return int_set_contains(v2->val.ad, ipa_to_u32(v1->val.ip));

  if ((v1->type == T_EC) && (v2->type == T_ECLIST))
    return ec_set_contains(v2->val.ad, v1->val.ec);

  if ((v1->type == T_LC) && (v2->type == T_LCLIST))
    return lc_set_contains(v2->val.ad, v1->val.lc);

  if ((v1->type == T_STRING) && (v2->type == T_STRING))
    return patmatch(v2->val.s, v1->val.s);

  if ((v1->type == T_IP) && (v2->type == T_NET))
    return ipa_in_netX(v1->val.ip, v2->val.net);

  if ((v1->type == T_NET) && (v2->type == T_NET))
    return net_in_netX(v1->val.net, v2->val.net);

  if ((v1->type == T_NET) && (v2->type == T_PREFIX_SET))
    return trie_match_net(v2->val.ti, v1->val.net);

  if (v2->type != T_SET)
    return F_CMP_ERROR;

  if (!v2->val.t)
    return 0;

  /* With integrated Quad<->IP implicit conversion */
  if ((v1->type == v2->val.t->from.type) ||
      ((v1->type == T_QUAD) && val_is_ip4(&(v2->val.t->from)) && val_is_ip4(&(v2->val.t->to))))
    return !!find_tree(v2->val.t, v1);

  if (v1->type == T_CLIST)
    return clist_match_set(v1->val.ad, v2->val.t);

  if (v1->type == T_ECLIST)
    return eclist_match_set(v1->val.ad, v2->val.t);

  if (v1->type == T_LCLIST)
    return lclist_match_set(v1->val.ad, v2->val.t);

  if (v1->type == T_PATH)
    return as_path_match_set(v1->val.ad, v2->val.t);

  return F_CMP_ERROR;
}

/* Compute a hash of a single filter value using a fresh hash state. */
uint
val_hash(struct f_val *v)
{
  u64 state;

  mem_hash_init(&state);
  mem_hash_mix_f_val(&state, v);

  return mem_hash_value(&state);
}

/*
 * Mix filter value @v into the running hash state @h. The type tag is
 * always mixed in first; what follows depends on the type. Types listed in
 * the final group cannot be hashed and end in bug().
 */
void
mem_hash_mix_f_val(u64 *h, struct f_val *v)
{
  mem_hash_mix_num(h, v->type);

/* MX(k) mixes the raw bytes of union member k; IT(k) names that member */
#define MX(k) mem_hash_mix(h, &IT(k), sizeof IT(k));
#define IT(k) v->val.k

  switch (v->type)
  {
    case T_VOID:
      /* Nothing beyond the type tag */
      break;
    case T_INT:
    case T_BOOL:
    case T_PAIR:
    case T_QUAD:
    case T_ENUM:
      MX(i);
      break;
    case T_EC:
    case T_RD:
      MX(ec);
      break;
    case T_LC:
      MX(lc);
      break;
    case T_IP:
      MX(ip);
      break;
    case T_NET:
      mem_hash_mix_num(h, net_hash(IT(net)));
      break;
    case T_STRING:
      mem_hash_mix_str(h, IT(s));
      break;
    case T_PATH_MASK:
      /*
       * NOTE(review): sizeof is applied to the item array member as a whole,
       * not to one element. If item is declared as a zero-length or flexible
       * array, the per-item term may not scale with len -- confirm against
       * the definition of struct f_path_mask.
       */
      mem_hash_mix(h, IT(path_mask), sizeof(*IT(path_mask)) + IT(path_mask)->len * sizeof (IT(path_mask)->item));
      break;
    case T_PATH:
    case T_CLIST:
    case T_ECLIST:
    case T_LCLIST:
    case T_BYTESTRING:
    case T_ROA_AGGREGATED:
      /* Attribute data: mix the payload bytes */
      mem_hash_mix(h, IT(ad)->data, IT(ad)->length);
      break;
    case T_SET:
      /* Mixes the pointer value itself, not the contents of the tree */
      MX(t);
      break;
    case T_PREFIX_SET:
      /* Mixes the pointer value itself, not the contents of the trie */
      MX(ti);
      break;

    /* Types for which no hashing is defined */
    case T_NONE:
    case T_PATH_MASK_ITEM:
    case T_ROUTE:
    case T_ROUTES_BLOCK:
    case T_OPAQUE:
    case T_NEXTHOP_LIST:
    case T_HOSTENTRY:
    case T_IFACE:
    case T_PTR:
    case T_ENUM_STATE:
    case T_BTIME:
    case T_BMP_CLOSING:
      bug("Invalid type %s in f_val hashing", f_type_name(v->type));
  }
}

/*
 * rte_format - append a one-line description of a route to a buffer;
 * a NULL route is printed as "[No route]"
 */
static void
rte_format(const struct rte *rte, buffer *buf)
{
  if (!rte)
  {
    buffer_puts(buf, "[No route]");
    return;
  }

  buffer_print(buf, "Route [%d] to %N from %s via %s",
               rte->src->global_id, rte->net,
               rte->sender->req->name,
               rte->src->owner->name);
}

/*
 * Append a listing of every route in @block to @buf: a heading followed by
 * "index: route" entries, separated by "; ".
 */
static void
rte_block_format(const struct rte_block *block, buffer *buf)
{
  buffer_print(buf, "Block of routes:");

  uint idx = 0;
  while (idx < block->len)
  {
    /* The first entry follows the heading after a space, later ones after "; " */
    const char *sep = idx ? "; " : " ";

    buffer_print(buf, "%s%d: ", sep, idx);
    rte_format(block->rte[idx], buf);
    idx++;
  }
}

/*
 * val_format - append a textual form of filter value @v to @buf
 *
 * Types whose formatters write into a plain character array go through a
 * local scratch buffer first. Types not handled here are printed as
 * "[unknown type N]".
 */
void
val_format(const struct f_val *v, buffer *buf)
{
  char scratch[1024];

  switch (v->type)
  {
  case T_VOID:
    buffer_puts(buf, "(void)");
    break;

  case T_BOOL:
    buffer_puts(buf, v->val.i ? "TRUE" : "FALSE");
    break;

  case T_INT:
    buffer_print(buf, "%u", v->val.i);
    break;

  case T_STRING:
    buffer_print(buf, "%s", v->val.s);
    break;

  case T_BYTESTRING:
    bstrbintohex(v->val.bs->data, v->val.bs->length, scratch, 1000, ':');
    buffer_print(buf, "%s", scratch);
    break;

  case T_IP:
    buffer_print(buf, "%I", v->val.ip);
    break;

  case T_NET:
    buffer_print(buf, "%N", v->val.net);
    break;

  case T_PAIR:
    /* Upper and lower 16 bits are the two halves of the pair */
    buffer_print(buf, "(%u,%u)", v->val.i >> 16, v->val.i & 0xffff);
    break;

  case T_QUAD:
    buffer_print(buf, "%R", v->val.i);
    break;

  case T_EC:
    ec_format(scratch, v->val.ec);
    buffer_print(buf, "%s", scratch);
    break;

  case T_LC:
    lc_format(scratch, v->val.lc);
    buffer_print(buf, "%s", scratch);
    break;

  case T_RD:
    rd_format(v->val.ec, scratch, 1024);
    buffer_print(buf, "%s", scratch);
    break;

  case T_PREFIX_SET:
    trie_format(v->val.ti, buf);
    break;

  case T_SET:
    tree_format(v->val.t, buf);
    break;

  case T_ENUM:
    buffer_print(buf, "(enum %x)%u", v->type, v->val.i);
    break;

  case T_PATH:
    as_path_format(v->val.ad, scratch, 1000);
    buffer_print(buf, "(path %s)", scratch);
    break;

  case T_CLIST:
    int_set_format(v->val.ad, 1, -1, scratch, 1000);
    buffer_print(buf, "(clist %s)", scratch);
    break;

  case T_ECLIST:
    ec_set_format(v->val.ad, -1, scratch, 1000);
    buffer_print(buf, "(eclist %s)", scratch);
    break;

  case T_LCLIST:
    lc_set_format(v->val.ad, -1, scratch, 1000);
    buffer_print(buf, "(lclist %s)", scratch);
    break;

  case T_PATH_MASK:
    pm_format(v->val.path_mask, buf);
    break;

  case T_ROUTE:
    rte_format(v->val.rte, buf);
    break;

  case T_ROUTES_BLOCK:
    rte_block_format(&v->val.rte_block, buf);
    break;

  default:
    buffer_print(buf, "[unknown type %x]", v->type);
    break;
  }
}

/*
 * Format value @v into a temporary 1024-byte stack buffer and return a copy
 * of the resulting string allocated from linpool @lp.
 */
char *
val_format_str(struct linpool *lp, const struct f_val *v)
{
  buffer scratch;
  STACK_BUFFER_INIT(scratch, 1024);

  val_format(v, &scratch);

  return lp_strdup(lp, scratch.start);
}


/* Backing storage for val_dump(); overwritten by every call */
static char val_dump_buffer[1024];

/*
 * Format value @v into a single static buffer and return a pointer to it.
 * The result is only valid until the next call, since the same storage is
 * reused.
 */
const char *
val_dump(const struct f_val *v)
{
  static buffer out = {
    .start = val_dump_buffer,
    .end = val_dump_buffer + sizeof(val_dump_buffer),
  };

  /* Rewind to the beginning before each use */
  out.pos = out.start;
  val_format(v, &out);

  return val_dump_buffer;
}

/*
 * lp_val_copy - copy a filter value into a linpool
 * @lp: linpool to allocate the copy from
 * @v: value to copy
 *
 * Returns a newly allocated copy of @v. For prefixes, strings and attribute
 * data lists (path, clist, eclist, lclist) the referenced data is copied
 * too and stored right behind the value in the same allocation. Sets,
 * prefix sets, path masks and interfaces are copied by pointer only.
 * Any other type ends in bug().
 */
struct f_val *
lp_val_copy(struct linpool *lp, const struct f_val *v)
{
  switch (v->type)
  {
    case T_VOID:
    case T_BOOL:
    case T_INT:
    case T_IP:
    case T_PAIR:
    case T_QUAD:
    case T_EC:
    case T_LC:
    case T_RD:
    case T_ENUM:
    case T_PATH_MASK_ITEM:
      /* These aren't embedded but there is no need to copy them */
    case T_SET:
    case T_PREFIX_SET:
    case T_PATH_MASK:
    case T_IFACE:
      {
	struct f_val *out = lp_alloc(lp, sizeof(*out));
	*out = *v;
	return out;
      }

    case T_NET:
      {
	/* Value followed by the prefix it points to */
	struct {
	  struct f_val val;
	  net_addr net[0];
	} *out = lp_alloc(lp, sizeof(*out) + v->val.net->length);
	out->val = *v;
	out->val.val.net = out->net;
	net_copy(out->net, v->val.net);
	return &out->val;
      }

    case T_STRING:
      {
	/* Value followed by the string including its terminating zero */
	uint len = strlen(v->val.s);
	struct {
	  struct f_val val;
	  char buf[0];
	} *out = lp_alloc(lp, sizeof(*out) + len + 1);
	out->val = *v;
	out->val.val.s = out->buf;
	memcpy(out->buf, v->val.s, len+1);
	return &out->val;
      }

    case T_PATH:
    case T_CLIST:
    case T_ECLIST:
    case T_LCLIST:
      {
	/* Value followed by the attribute data header and its payload */
	struct {
	  struct f_val val;
	  struct adata ad;
	} *out = lp_alloc(lp, sizeof(*out) + v->val.ad->length);
	out->val = *v;
	out->val.val.ad = &out->ad;
	/*
	 * The length field counts payload bytes only, so the header size has
	 * to be added; otherwise the tail of the payload is lost and an
	 * empty list would not even get its length field copied.
	 */
	memcpy(&out->ad, v->val.ad, sizeof(out->ad) + v->val.ad->length);
	return &out->val;
      }

    default:
      bug("Unknown type in value copy: %d", v->type);
  }
}