#include <test/test.h>
#include <test/mock.h>
#include <net/ipv6.h>
#include <net/ip6_fib.h>
#include <net/ip6_route.h>

/*
 * Per-test-case fixture, allocated in rt6_test_init() and stored in
 * test->priv.
 * @net:     namespace obtained from net_create_fake()
 * @rt_info: RTF_REJECT route built by rt6_info_create_fake(); the test
 *           cases attach exceptions to its rt6i_exception_bucket array.
 */
struct rt6_test {
	struct net *net;
	struct fib6_info *rt_info;
};

/*
 * Routines internal to the IPv6 routing code that this test calls directly.
 * NOTE(review): these are presumably exported only for test builds via
 * __visible_for_testing — confirm against the route.c of this tree.
 */
__visible_for_testing struct fib6_info *
ip6_route_info_create(struct fib6_config *cfg,
		      struct netlink_ext_ack *extack);

/*
 * Looks up the exception for daddr/saddr. Takes the bucket by reference;
 * rt6_add_exception() relies on it being left pointing at the slot where a
 * new entry for daddr belongs.
 */
__visible_for_testing struct rt6_exception *
__rt6_find_exception_spinlock(struct rt6_exception_bucket **bucket,
			      const struct in6_addr *daddr,
			      const struct in6_addr *saddr);

/* Allocates the per-CPU rt6_info clone of rt that is used as an exception. */
__visible_for_testing struct rt6_info *ip6_rt_pcpu_alloc(struct fib6_info *rt);

/*
 * Mock of rt6_remove_exception(). The test cases put call-count
 * expectations on it to observe which exceptions rt6_age_exceptions()
 * decides to drop. NOTE(review): presumably the mock replaces the real
 * function, so entries are not actually unlinked — confirm.
 */
DEFINE_FUNCTION_MOCK_VOID_RETURN(rt6_remove_exception,
				 PARAMS(struct rt6_exception_bucket *, struct rt6_exception *));

/*
 * rt6_info_create_fake - build an RTF_REJECT /128 route to @addr in @net.
 * @test: owning test case; a failed assertion aborts it
 * @net:  namespace the route is created in
 * @addr: destination address (prefix length 128)
 *
 * Any extack message reported by ip6_route_info_create() counts as an
 * expectation failure and is also logged as a warning.
 *
 * Return: the new fib6_info; the test is aborted if it is NULL or ERR_PTR.
 */
static struct fib6_info *rt6_info_create_fake(struct test *test, struct net *net, struct in6_addr *addr)
{
	/* Fields not named below (ifindex, metric, protocol, src_len) are zero. */
	struct fib6_config cfg = {
		.fc_nlinfo.nl_net = net,
		.fc_dst = *addr,
		.fc_dst_len = 128,
		.fc_flags = RTF_REJECT,
	};
	struct netlink_ext_ack extack = {};
	struct fib6_info *rt_info = ip6_route_info_create(&cfg, &extack);

	EXPECT_EQ(test, NULL, extack._msg);
	if (extack._msg != NULL)
		test_warn(test, "Error creating rt_info: %s", extack._msg);
	ASSERT_NOT_ERR_OR_NULL(test, rt_info);

	return rt_info;
}

/*
 * rt6_add_exception - link @rt6_ex_info into @rt_info's exception buckets.
 * @test:        owning test case; a failed assertion aborts it
 * @rt_info:     route whose rt6i_exception_bucket array receives the entry
 * @rt6_ex_info: cached clone to register as an exception
 *
 * Asserts that no exception for the same destination exists yet. It then
 * links a new rt6_exception at the head of the selected bucket chain and
 * bumps the bucket depth and the namespace's fib_rt_cache counter.
 *
 * Return: the newly linked rt6_exception (test-managed memory).
 */
static struct rt6_exception *rt6_add_exception(struct test *test, struct fib6_info *rt_info, struct rt6_info *rt6_ex_info)
{
	struct net *net = dev_net(rt_info->fib6_nh.nh_dev);
	struct rt6_exception_bucket *bucket = rt_info->rt6i_exception_bucket;
	struct rt6_exception *rt6_ex;

	/*
	 * Without a bucket array there is nowhere to link the entry, and
	 * bucket->chain below would dereference NULL. Fail the test instead.
	 */
	ASSERT_NOT_ERR_OR_NULL(test, bucket);

	/*
	 * The lookup takes &bucket. NOTE(review): it presumably advances it to
	 * the hash slot for the destination, which is the chain used below.
	 */
	rt6_ex = __rt6_find_exception_spinlock(&bucket, &rt6_ex_info->rt6i_dst.addr, NULL);
	ASSERT_FALSE(test, rt6_ex);

	rt6_ex = test_kzalloc(test, sizeof(*rt6_ex), GFP_KERNEL);
	ASSERT_NOT_ERR_OR_NULL(test, rt6_ex);

	rt6_ex->rt6i = rt6_ex_info;
	hlist_add_head(&rt6_ex->hlist, &bucket->chain);
	bucket->depth++;
	net->ipv6.rt6_stats->fib_rt_cache++;

	return rt6_ex;
}

static void
rt6_age_exceptions_test_rtf_expires_after_expiration(struct test *test)
{
	struct rt6_test *ctx = test->priv;
	struct fib6_info *rt_info = ctx->rt_info;
	struct rt6_info *rt6_ex_info;
	struct mock_expectation *handle;
	struct fib6_gc_args gc_args;

	rt6_ex_info = ip6_rt_pcpu_alloc(rt_info);
	rt6_ex_info->rt6i_flags |= RTF_EXPIRES;
	rt6_ex_info->dst.expires = 1;
	rt6_add_exception(test, rt_info, rt6_ex_info);

	handle = EXPECT_CALL(rt6_remove_exception(any(test), any(test)));
	handle->min_calls_expected = 1;
	handle->max_calls_expected = 1;
	handle->action = int_return(test, 0);

	rt6_age_exceptions(rt_info, &gc_args, 2);
}

/*
 * An exception without RTF_EXPIRES, last used at 1 with gc timeout 1, is
 * aged with now = 0 and must not be removed.
 */
static void
rt6_age_exceptions_test_pmtu_expires_not_before_timeout(struct test *test)
{
	struct rt6_test *ctx = test->priv;
	struct fib6_info *rt_info = ctx->rt_info;
	struct rt6_info *rt6_ex_info;
	struct mock_expectation *handle;
	struct fib6_gc_args gc_args = {
		.timeout = 1,
	};

	rt6_ex_info = ip6_rt_pcpu_alloc(rt_info);
	/* The clone is dereferenced right away; fail cleanly if it is missing. */
	ASSERT_NOT_ERR_OR_NULL(test, rt6_ex_info);
	rt6_ex_info->rt6i_flags &= ~RTF_EXPIRES;
	rt6_ex_info->dst.lastuse = 1;
	rt6_add_exception(test, rt_info, rt6_ex_info);

	handle = EXPECT_CALL(rt6_remove_exception(any(test), any(test)));
	handle->min_calls_expected = 0;
	handle->max_calls_expected = 0;
	handle->action = int_return(test, 0);

	rt6_age_exceptions(rt_info, &gc_args, 0);
}

/*
 * An exception without RTF_EXPIRES, last used at 1 with gc timeout 1, is
 * aged with now = 3 and is expected to be removed exactly once.
 */
static void
rt6_age_exceptions_test_pmtu_expires_after_timeout(struct test *test)
{
	struct rt6_test *ctx = test->priv;
	struct fib6_info *rt_info = ctx->rt_info;
	struct rt6_info *rt6_ex_info;
	struct mock_expectation *handle;
	struct fib6_gc_args gc_args = {
		.timeout = 1,
	};

	rt6_ex_info = ip6_rt_pcpu_alloc(rt_info);
	/* The clone is dereferenced right away; fail cleanly if it is missing. */
	ASSERT_NOT_ERR_OR_NULL(test, rt6_ex_info);
	rt6_ex_info->rt6i_flags &= ~RTF_EXPIRES;
	rt6_ex_info->dst.lastuse = 1;
	rt6_add_exception(test, rt_info, rt6_ex_info);

	handle = EXPECT_CALL(rt6_remove_exception(any(test), any(test)));
	handle->min_calls_expected = 1;
	handle->max_calls_expected = 1;
	handle->action = int_return(test, 0);

	rt6_age_exceptions(rt_info, &gc_args, 3);
}

static void
rt6_age_exceptions_test_pmtu_expires_with_no_neighbor(struct test *test)
{
	struct rt6_test *ctx = test->priv;
	struct fib6_info *rt_info = ctx->rt_info;
	struct rt6_info *rt6_ex_info;
	struct mock_expectation *handle;
	struct fib6_gc_args gc_args;

	rt6_ex_info = ip6_rt_pcpu_alloc(rt_info);
	rt6_ex_info->rt6i_flags |= RTF_GATEWAY;
	rt6_ex_info->dst.expires = 1;
	rt6_add_exception(test, rt_info, rt6_ex_info);

	handle = EXPECT_CALL(rt6_remove_exception(any(test), any(test)));
	handle->min_calls_expected = 1;
	handle->max_calls_expected = 1;
	handle->action = int_return(test, 0);

	rt6_age_exceptions(rt_info, &gc_args, 0);
}

static void
rt6_age_exceptions_test_pmtu_not_expire_with_router_neighbor(struct test *test)
{
	struct rt6_test *ctx = test->priv;
	struct fib6_info *rt_info = ctx->rt_info;
	struct rt6_info *rt6_ex_info;
	struct mock_expectation *handle;
	struct fib6_gc_args gc_args;
	struct inet6_dev *in6_dev;
	struct neighbour *neigh;

	rt6_ex_info = ip6_rt_pcpu_alloc(rt_info);
	rt6_ex_info->rt6i_flags |= RTF_GATEWAY;
	rt6_ex_info->dst.expires = 1;
	rt6_add_exception(test, rt_info, rt6_ex_info);

	in6_dev = in6_dev_get(ctx->net->loopback_dev);
	ASSERT_NOT_ERR_OR_NULL(test, in6_dev);
	in6_dev->nd_parms = neigh_parms_clone(&nd_tbl.parms);
	in6_dev_put(in6_dev);
	neigh = __neigh_create(&nd_tbl, &rt6_ex_info->rt6i_gateway, ctx->net->loopback_dev, false);
	ASSERT_NOT_ERR_OR_NULL(test, neigh);
	neigh = __ipv6_neigh_lookup_noref(rt6_ex_info->dst.dev, &rt6_ex_info->rt6i_gateway);
	ASSERT_NOT_ERR_OR_NULL(test, neigh);
	neigh->flags |= NTF_ROUTER;

	handle = EXPECT_CALL(rt6_remove_exception(any(test), any(test)));
	handle->min_calls_expected = 0;
	handle->max_calls_expected = 0;
	handle->action = int_return(test, 0);

	rt6_age_exceptions(rt_info, &gc_args, 0);
}

static void
rt6_age_exceptions_test_pmtu_eventually_expire_with_router_neighbor(struct test *test)
{
	struct rt6_test *ctx = test->priv;
	struct fib6_info *rt_info = ctx->rt_info;
	struct rt6_info *rt6_ex_info;
	struct mock_expectation *handle;
	struct fib6_gc_args gc_args;
	struct inet6_dev *in6_dev;
	struct neighbour *neigh;

	rt6_ex_info = ip6_rt_pcpu_alloc(rt_info);
	rt6_ex_info->rt6i_flags |= RTF_EXPIRES | RTF_GATEWAY;
	rt6_ex_info->dst.expires = 1;
	rt6_add_exception(test, rt_info, rt6_ex_info);

	in6_dev = in6_dev_get(ctx->net->loopback_dev);
	ASSERT_NOT_ERR_OR_NULL(test, in6_dev);
	in6_dev->nd_parms = neigh_parms_clone(&nd_tbl.parms);
	in6_dev_put(in6_dev);
	neigh = __neigh_create(&nd_tbl, &rt6_ex_info->rt6i_gateway, ctx->net->loopback_dev, false);
	ASSERT_NOT_ERR_OR_NULL(test, neigh);
	neigh = __ipv6_neigh_lookup_noref(rt6_ex_info->dst.dev, &rt6_ex_info->rt6i_gateway);
	ASSERT_NOT_ERR_OR_NULL(test, neigh);
	neigh->flags |= NTF_ROUTER;

	handle = EXPECT_CALL(rt6_remove_exception(any(test), any(test)));
	handle->min_calls_expected = 1;
	handle->max_calls_expected = 1;
	handle->action = int_return(test, 0);

	rt6_age_exceptions(rt_info, &gc_args, 2);
}

static void
rt6_age_exceptions_test_pmtu_expire_with_non_router_neighbor(struct test *test)
{
	struct rt6_test *ctx = test->priv;
	struct fib6_info *rt_info = ctx->rt_info;
	struct rt6_info *rt6_ex_info;
	struct mock_expectation *handle;
	struct fib6_gc_args gc_args;
	struct inet6_dev *in6_dev;
	struct neighbour *neigh;

	rt6_ex_info = ip6_rt_pcpu_alloc(rt_info);
	rt6_ex_info->rt6i_flags |= RTF_GATEWAY;
	rt6_ex_info->dst.expires = 1;
	rt6_add_exception(test, rt_info, rt6_ex_info);

	in6_dev = in6_dev_get(ctx->net->loopback_dev);
	ASSERT_NOT_ERR_OR_NULL(test, in6_dev);
	in6_dev->nd_parms = neigh_parms_clone(&nd_tbl.parms);
	in6_dev_put(in6_dev);
	neigh = __neigh_create(&nd_tbl, &rt6_ex_info->rt6i_gateway, ctx->net->loopback_dev, false);
	ASSERT_NOT_ERR_OR_NULL(test, neigh);
	neigh = __ipv6_neigh_lookup_noref(rt6_ex_info->dst.dev, &rt6_ex_info->rt6i_gateway);
	ASSERT_NOT_ERR_OR_NULL(test, neigh);
	neigh->flags &= ~NTF_ROUTER;

	handle = EXPECT_CALL(rt6_remove_exception(any(test), any(test)));
	handle->min_calls_expected = 1;
	handle->max_calls_expected = 1;
	handle->action = int_return(test, 0);

	rt6_age_exceptions(rt_info, &gc_args, 0);
}

/* Device ops table with static storage, so every callback pointer is NULL. */
static const struct net_device_ops fake_net_device_ops;

/*
 * net_device_fake_init - alloc_netdev() setup callback for the fake device.
 * @dev: device being set up
 *
 * Installs the empty ops table and attaches a zeroed inet6_dev whose
 * refcount starts at 1. If that allocation fails, dev->ip6_ptr is NULL.
 */
static void net_device_fake_init(struct net_device *dev)
{
	struct inet6_dev *fake_idev = kzalloc(sizeof(*fake_idev), GFP_KERNEL);

	dev->netdev_ops = &fake_net_device_ops;
	dev->ip6_ptr = fake_idev;

	if (!fake_idev)
		return;

	refcount_set(&fake_idev->refcnt, 1);
}

/*
 * net_device_create_fake - allocate a fake "fake%d" device.
 *
 * The device is set up by net_device_fake_init() and is not registered
 * here.
 *
 * Return: the result of alloc_netdev().
 */
static struct net_device *net_device_create_fake(void)
{
	struct net_device *fake_dev;

	fake_dev = alloc_netdev(0, "fake%d", NET_NAME_UNKNOWN,
				net_device_fake_init);

	return fake_dev;
}

static struct net *net_create_fake(void)
{
	struct net *net;

	net = copy_net_ns(0, task_cred_xxx(current, user_ns), current->nsproxy->net_ns);
	if (net) {
		net->loopback_dev = net_device_create_fake();
	}

	return net;
}

/*
 * rt6_test_init - per-test-case fixture setup.
 * @test: test case; its priv is set to a new struct rt6_test
 *
 * Creates a namespace, then an RTF_REJECT route to the all-zero address in
 * it. Finally attaches a zeroed array of FIB6_EXCEPTION_BUCKET_SIZE
 * exception buckets to that route.
 *
 * Return: 0 on success, -ENOMEM if any allocation fails.
 */
static int rt6_test_init(struct test *test)
{
	struct rt6_test *ctx;
	struct in6_addr addr;

	ctx = test_kzalloc(test, sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;
	test->priv = ctx;

	ctx->net = net_create_fake();
	if (!ctx->net)
		return -ENOMEM;

	ipv6_addr_set(&addr, 0, 0, 0, 0);
	ctx->rt_info = rt6_info_create_fake(test, ctx->net, &addr);

	EXPECT_FALSE(test, ctx->rt_info->rt6i_exception_bucket);
	/*
	 * Use plain kcalloc(), not test_kzalloc(). The array becomes owned by
	 * rt_info and is kfree()d on the fib6_info destroy path
	 * (fib6_info_destroy_rcu() — NOTE(review): confirm for this tree).
	 * A test-managed allocation would therefore be freed twice.
	 */
	ctx->rt_info->rt6i_exception_bucket =
			kcalloc(FIB6_EXCEPTION_BUCKET_SIZE,
				sizeof(*ctx->rt_info->rt6i_exception_bucket),
				GFP_KERNEL);
	if (!ctx->rt_info->rt6i_exception_bucket)
		return -ENOMEM;

	return 0;
}

/*
 * rt6_test_exit - per-test-case teardown.
 * @test: test case whose priv was set up by rt6_test_init()
 *
 * Tolerates a partially initialized fixture: each resource is released only
 * if it was actually created.
 */
static void rt6_test_exit(struct test *test)
{
	struct rt6_test *ctx = test->priv;

	if (!ctx)
		return;

	if (ctx->rt_info)
		fib6_info_release(ctx->rt_info);

	/*
	 * NOTE(review): the fake loopback_dev installed by net_create_fake()
	 * is never registered, so unregister_netdev() does not apply. It is
	 * deliberately not freed here, because rt_info's deferred release may
	 * still reference it.
	 */

	/*
	 * put_net() drops our reference and tears the namespace down only if
	 * it was the last one. __put_net() would tear it down unconditionally,
	 * even if it is shared.
	 */
	if (ctx->net)
		put_net(ctx->net);
}

/* rt6_age_exceptions() cases; the empty entry terminates the list. */
static struct test_case rt6_test_cases[] = {
	TEST_CASE(rt6_age_exceptions_test_rtf_expires_after_expiration),
	TEST_CASE(rt6_age_exceptions_test_pmtu_expires_not_before_timeout),
	TEST_CASE(rt6_age_exceptions_test_pmtu_expires_after_timeout),
	TEST_CASE(rt6_age_exceptions_test_pmtu_expires_with_no_neighbor),
	TEST_CASE(rt6_age_exceptions_test_pmtu_not_expire_with_router_neighbor),
	TEST_CASE(rt6_age_exceptions_test_pmtu_eventually_expire_with_router_neighbor),
	TEST_CASE(rt6_age_exceptions_test_pmtu_expire_with_non_router_neighbor),
	{},
};

/*
 * "rt6-test" module: rt6_test_init()/rt6_test_exit() set up and tear down
 * the fixture around each entry of rt6_test_cases.
 */
static struct test_module rt6_test_module = {
	.name = "rt6-test",
	.init = rt6_test_init,
	.exit = rt6_test_exit,
	.test_cases = rt6_test_cases,
};
module_test(rt6_test_module);
