// SPDX-License-Identifier: GPL-2.0
/*
 * Base unit test (KUnit) API.
 *
 * Copyright (C) 2019, Google LLC.
 * Author: Brendan Higgins <brendanhiggins@google.com>
 */

#include <linux/list.h>
#include <linux/sched/debug.h>
#include <kunit/test.h>
#include <kunit/try-catch.h>

/* Returns test->success, read while holding test->lock. */
static bool kunit_get_success(struct kunit *test)
{
	unsigned long irq_flags;
	bool result;

	spin_lock_irqsave(&test->lock, irq_flags);
	result = test->success;
	spin_unlock_irqrestore(&test->lock, irq_flags);
	return result;
}

/* Stores @success into test->success while holding test->lock. */
static void kunit_set_success(struct kunit *test, bool success)
{
	unsigned long irq_flags;

	spin_lock_irqsave(&test->lock, irq_flags);
	test->success = success;
	spin_unlock_irqrestore(&test->lock, irq_flags);
}

/*
 * Returns test->death_test, read while holding test->lock. The flag is set
 * by kunit_abort() to mark an intentional abort of the test case.
 */
static bool kunit_get_death_test(struct kunit *test)
{
	unsigned long irq_flags;
	bool aborted;

	spin_lock_irqsave(&test->lock, irq_flags);
	aborted = test->death_test;
	spin_unlock_irqrestore(&test->lock, irq_flags);
	return aborted;
}

/*
 * Forwards a va_list-formatted message to vprintk_emit() at numeric log
 * level @level, with facility 0 and no dictionary.
 */
static int kunit_vprintk_emit(int level, const char *fmt, va_list args)
{
	const int facility = 0;

	return vprintk_emit(facility, level, NULL, 0, fmt, args);
}

/*
 * Variadic front end for kunit_vprintk_emit(); returns whatever
 * kunit_vprintk_emit() returns.
 */
static int kunit_printk_emit(int level, const char *fmt, ...)
{
	va_list ap;
	int printed;

	va_start(ap, fmt);
	printed = kunit_vprintk_emit(level, fmt, ap);
	va_end(ap);
	return printed;
}

/*
 * Prints "\t# <test name>: <message>" at the log level encoded in @level.
 * Assumes @level is a KERN_* prefix string whose second character is the
 * loglevel digit (NOTE(review): confirm against callers).
 */
static void kunit_vprintk(const struct kunit *test,
			  const char *level,
			  struct va_format *vaf)
{
	int loglevel = level[1] - '0';

	kunit_printk_emit(loglevel, "\t# %s: %pV", test->name, vaf);
}

/* Set once the TAP version header has been emitted. */
static bool kunit_has_printed_tap_version;

/* Emits the "TAP version 14" header, but only on the first call. */
static void kunit_print_tap_version(void)
{
	if (kunit_has_printed_tap_version)
		return;

	kunit_printk_emit(LOGLEVEL_INFO, "TAP version 14\n");
	kunit_has_printed_tap_version = true;
}

/*
 * Counts the entries of a test case array, which is terminated by an entry
 * whose run_case pointer is NULL (the terminator is not counted).
 */
static size_t kunit_test_cases_len(struct kunit_case *test_cases)
{
	size_t count = 0;

	while (test_cases[count].run_case)
		count++;

	return count;
}

/*
 * Prints the opening of a module's TAP subtest: the TAP version header (only
 * the first time), a "# Subtest: <name>" line, and the "1..N" plan where N is
 * the number of test cases in the module.
 */
static void kunit_print_subtest_start(struct kunit_module *module)
{
	kunit_print_tap_version();
	kunit_printk_emit(LOGLEVEL_INFO, "\t# Subtest: %s\n", module->name);
	/* kunit_test_cases_len() returns size_t, whose specifier is %zu. */
	kunit_printk_emit(LOGLEVEL_INFO,
			  "\t1..%zu\n",
			  kunit_test_cases_len(module->test_cases));
}

/*
 * Prints one TAP result line of the form "[\t]ok N - description" or
 * "[\t]not ok N - description". Callers in this file pass should_indent=true
 * for individual test cases and false for a whole module's result.
 *
 * @should_indent: prefix the line with a tab.
 * @is_ok: selects "ok" vs "not ok".
 * @test_number: the TAP test number N.
 * @description: text printed after " - ".
 */
static void kunit_print_ok_not_ok(bool should_indent,
				  bool is_ok,
				  size_t test_number,
				  const char *description)
{
	const char *indent = should_indent ? "\t" : "";
	const char *ok_not_ok = is_ok ? "ok" : "not ok";

	/* test_number is a size_t, so it must be printed with %zu, not %zd. */
	kunit_printk_emit(LOGLEVEL_INFO,
			  "%s%s %zu - %s\n",
			  indent, ok_not_ok, test_number, description);
}

/*
 * Returns true when every test case of @module recorded success, false as
 * soon as one failed case is found.
 */
static bool kunit_module_has_succeeded(struct kunit_module *module)
{
	struct kunit_case *test_case = module->test_cases;

	for (; test_case->run_case; test_case++) {
		if (!test_case->success)
			return false;
	}

	return true;
}

/* TAP number assigned to the next module result; starts at 1. */
size_t kunit_module_counter = 1;

/*
 * Prints the unindented TAP result line for a whole module and advances
 * kunit_module_counter.
 */
static void kunit_print_subtest_end(struct kunit_module *module)
{
	bool module_ok = kunit_module_has_succeeded(module);
	size_t module_number = kunit_module_counter++;

	kunit_print_ok_not_ok(false, module_ok, module_number, module->name);
}

/* Prints the indented TAP result line for a single test case. */
static void kunit_print_test_case_ok_not_ok(struct kunit_case *test_case,
					    size_t test_number)
{
	const bool indented = true;

	kunit_print_ok_not_ok(indented, test_case->success, test_number,
			      test_case->name);
}

/*
 * Marks @test as failed, then sets @stream's level to KERN_ERR and commits
 * it.
 */
void kunit_fail(struct kunit *test, struct kunit_stream *stream)
{
	kunit_set_success(test, false);

	kunit_stream_set_level(stream, KERN_ERR);
	kunit_stream_commit(stream);
}

/*
 * Aborts the currently running test case.
 *
 * The test is first flagged as an expected death (death_test = true), so
 * that kunit_catch_run_case() takes the normal cleanup path instead of the
 * crash handler. Control then leaves the test via kunit_try_catch_throw(),
 * which is marked __noreturn and so should never come back here.
 */
void __noreturn kunit_abort(struct kunit *test)
{
	kunit_set_death_test(test, true);

	kunit_try_catch_throw(&test->try_catch);

	/*
	 * Throw could not abort from test.
	 *
	 * XXX: we should never reach this line! As kunit_try_catch_throw is
	 * marked __noreturn.
	 */
	WARN_ONCE(true, "Throw could not abort from test!\n");
}

/*
 * Prepares @test for use: records its name, empties its resource list and
 * initializes its lock. success and death_test are not touched here; the
 * caller sets them.
 */
void kunit_init_test(struct kunit *test, const char *name)
{
	test->name = name;
	INIT_LIST_HEAD(&test->resources);
	spin_lock_init(&test->lock);
}

/*
 * Runs the module's optional init hook and then the test case body.
 * If init returns non-zero, the error is logged, the test is marked failed
 * and the body is skipped. No cleanup or post-validation happens here.
 */
static void kunit_run_case_internal(struct kunit *test,
				    struct kunit_module *module,
				    struct kunit_case *test_case)
{
	int init_err = module->init ? module->init(test) : 0;

	if (init_err) {
		kunit_err(test, "failed to initialize: %d\n", init_err);
		kunit_set_success(test, false);
		return;
	}

	test_case->run_case(test);
}

/* Framework-side teardown for a test case: frees the test's resources. */
static void kunit_case_internal_cleanup(struct kunit *test)
{
	kunit_cleanup(test);
}

/*
 * Post-run teardown for a test case: calls the module's exit hook when one
 * is provided, then performs the framework's internal cleanup.
 * XXX: Must ONLY be called after kunit_run_case_internal()!
 */
static void kunit_run_case_cleanup(struct kunit *test,
				   struct kunit_module *module,
				   struct kunit_case *test_case)
{
	if (module->exit != NULL)
		module->exit(test);

	kunit_case_internal_cleanup(test);
}

/*
 * Handles an unexpected crash in a test case.
 *
 * Logs the crash, dumps a stack trace and releases the test's resources.
 * The module's exit hook is NOT run, because the case's state is unknown.
 * Marking the test as failed is left to the caller.
 */
static void kunit_handle_test_crash(struct kunit *test,
				   struct kunit_module *module,
				   struct kunit_case *test_case)
{
	/* Newline-terminated, like every other kunit_err() message here. */
	kunit_err(test, "kunit test case crashed!\n");
	/*
	 * TODO(brendanhiggins@google.com): This prints the stack trace up
	 * through this frame, not up to the frame that caused the crash.
	 */
	show_stack(NULL, NULL);

	kunit_case_internal_cleanup(test);
}

/*
 * Context handed (as void *) to kunit_try_run_case() and
 * kunit_catch_run_case(); filled in by kunit_run_case_catch_errors().
 */
struct kunit_try_catch_context {
	struct kunit *test;		/* per-case test state */
	struct kunit_module *module;	/* module supplying init/exit hooks */
	struct kunit_case *test_case;	/* the case being run */
};

/*
 * "Try" half of the test case runner: runs the case and, if it returns
 * normally, performs the regular post-run cleanup.
 */
static void kunit_try_run_case(void *data)
{
	struct kunit_try_catch_context *ctx = data;

	/*
	 * If the case hits a fatal error, abort is called and this thread
	 * exits. The parent thread then resumes and the catch handler takes
	 * care of any cleanup, so the second call below may never run.
	 */
	kunit_run_case_internal(ctx->test, ctx->module, ctx->test_case);
	kunit_run_case_cleanup(ctx->test, ctx->module, ctx->test_case);
}

/*
 * "Catch" half of the test case runner, invoked by the try-catch machinery
 * with the same context as kunit_try_run_case().
 *
 * - A non-zero try result means the case could not finish (timeout or an
 *   internal error). Its state is unknown, so the test is failed and no
 *   cleanup is attempted.
 * - Otherwise, if kunit_abort() flagged an expected death, state is safe
 *   and the normal cleanup path runs.
 * - Otherwise the death was unexpected. The crash handler runs and the test
 *   is failed.
 */
static void kunit_catch_run_case(void *data)
{
	struct kunit_try_catch_context *ctx = data;
	struct kunit *test = ctx->test;
	int result = kunit_try_catch_get_result(&test->try_catch);

	if (result == -ETIMEDOUT) {
		kunit_set_success(test, false);
		kunit_err(test, "test case timed out\n");
		return;
	}

	if (result != 0) {
		kunit_set_success(test, false);
		kunit_err(test, "internal error occurred preventing test case from running: %d\n",
			  result);
		return;
	}

	if (!kunit_get_death_test(test)) {
		kunit_handle_test_crash(test, ctx->module, ctx->test_case);
		kunit_set_success(test, false);
		return;
	}

	kunit_run_case_cleanup(test, ctx->module, ctx->test_case);
}

/*
 * Performs all logic to run a test case. It also catches most errors that
 * occurs in a test case and reports them as failures.
 */
static void kunit_run_case_catch_errors(struct kunit_module *module,
					struct kunit_case *test_case)
{
	struct kunit_try_catch_context context;
	struct kunit_try_catch *try_catch;
	struct kunit test;

	kunit_init_test(&test, test_case->name);
	try_catch = &test.try_catch;
	kunit_set_success(&test, true);
	kunit_set_death_test(&test, false);

	kunit_try_catch_init(try_catch,
			     &test,
			     kunit_try_run_case,
			     kunit_catch_run_case);
	context.test = &test;
	context.module = module;
	context.test_case = test_case;
	kunit_try_catch_run(try_catch, &context);

	test_case->success = kunit_get_success(&test);
}

/*
 * Runs every test case in @module, printing a TAP subtest with one result
 * line per case (numbered from 1) and a final module result. Always
 * returns 0.
 */
int kunit_run_tests(struct kunit_module *module)
{
	struct kunit_case *test_case = module->test_cases;
	size_t number;

	kunit_print_subtest_start(module);

	for (number = 1; test_case->run_case; number++, test_case++) {
		kunit_run_case_catch_errors(module, test_case);
		kunit_print_test_case_ok_not_ok(test_case, number);
	}

	kunit_print_subtest_end(module);
	return 0;
}

/*
 * Allocates a resource record, initializes it with @init(res, @context) and
 * registers it on @test so it is released by kunit_cleanup().
 *
 * @free is stored in the record and called on it when the resource is freed.
 * Returns the registered resource, or NULL if the record could not be
 * allocated or @init returned non-zero. In either failure case nothing is
 * leaked and nothing is registered.
 */
struct kunit_resource *kunit_alloc_resource(struct kunit *test,
					    kunit_resource_init_t init,
					    kunit_resource_free_t free,
					    void *context)
{
	struct kunit_resource *res;
	unsigned long flags;
	int ret;

	res = kzalloc(sizeof(*res), GFP_KERNEL);
	if (!res)
		return NULL;

	ret = init(res, context);
	if (ret) {
		/* The record was never registered, so it must be freed here. */
		kfree(res);
		return NULL;
	}

	res->free = free;
	spin_lock_irqsave(&test->lock, flags);
	list_add_tail(&res->node, &test->resources);
	spin_unlock_irqrestore(&test->lock, flags);

	return res;
}

void kunit_free_resource(struct kunit *test, struct kunit_resource *res)
{
	res->free(res);
	list_del(&res->node);
	kfree(res);
}

/* Arguments forwarded from kunit_kmalloc() to kunit_kmalloc_init(). */
struct kunit_kmalloc_params {
	size_t size;	/* number of bytes passed to kmalloc() */
	gfp_t gfp;	/* allocation flags passed to kmalloc() */
};

static int kunit_kmalloc_init(struct kunit_resource *res, void *context)
{
	struct kunit_kmalloc_params *params = context;

	res->allocation = kmalloc(params->size, params->gfp);
	if (!res->allocation)
		return -ENOMEM;

	return 0;
}

/* Resource free callback for kunit_kmalloc(): kfree()s res->allocation. */
static void kunit_kmalloc_free(struct kunit_resource *res)
{
	void *memory = res->allocation;

	kfree(memory);
}

/*
 * kmalloc() whose result is tracked as a resource of @test. The memory is
 * released when the test's resources are cleaned up. Returns NULL if
 * allocation fails.
 */
void *kunit_kmalloc(struct kunit *test, size_t size, gfp_t gfp)
{
	struct kunit_kmalloc_params params = {
		.size = size,
		.gfp = gfp,
	};
	struct kunit_resource *res = kunit_alloc_resource(test,
							   kunit_kmalloc_init,
							   kunit_kmalloc_free,
							   &params);

	return res ? res->allocation : NULL;
}

void kunit_cleanup(struct kunit *test)
{
	struct kunit_resource *resource, *resource_safe;
	unsigned long flags;

	spin_lock_irqsave(&test->lock, flags);
	list_for_each_entry_safe(resource,
				 resource_safe,
				 &test->resources,
				 node) {
		kunit_free_resource(test, resource);
	}
	spin_unlock_irqrestore(&test->lock, flags);
}

void kunit_printk(const char *level,
		  const struct kunit *test,
		  const char *fmt, ...)
{
	struct va_format vaf;
	va_list args;

	va_start(args, fmt);

	vaf.fmt = fmt;
	vaf.va = &args;

	kunit_vprintk(test, level, &vaf);

	va_end(args);
}

void kunit_expect_binary_msg(struct kunit *test,
			     long long left, const char *left_name,
			     long long right, const char *right_name,
			     bool compare_result,
			     const char *compare_name,
			     const char *file,
			     const char *line,
			     const char *fmt, ...)
{
	struct kunit_stream *stream = kunit_expect_start(test, file, line);
	struct va_format vaf;
	va_list args;

	kunit_stream_add(stream,
			 "Expected %s %s %s, but\n",
			 left_name, compare_name, right_name);
	kunit_stream_add(stream, "\t\t%s == %lld\n", left_name, left);
	kunit_stream_add(stream, "\t\t%s == %lld", right_name, right);

	if (fmt) {
		va_start(args, fmt);

		vaf.fmt = fmt;
		vaf.va = &args;

		kunit_stream_add(stream, "\n%pV", &vaf);

		va_end(args);
	}

	kunit_expect_end(test, compare_result, stream);
}

void kunit_expect_ptr_binary_msg(struct kunit *test,
				 void *left, const char *left_name,
				 void *right, const char *right_name,
				 bool compare_result,
				 const char *compare_name,
				 const char *file,
				 const char *line,
				 const char *fmt, ...)
{
	struct kunit_stream *stream = kunit_expect_start(test, file, line);
	struct va_format vaf;
	va_list args;

	kunit_stream_add(stream,
			 "Expected %s %s %s, but\n",
			 left_name, compare_name, right_name);
	kunit_stream_add(stream, "\t\t%s == %pK\n", left_name, left);
	kunit_stream_add(stream, "\t\t%s == %pK", right_name, right);

	if (fmt) {
		va_start(args, fmt);

		vaf.fmt = fmt;
		vaf.va = &args;

		kunit_stream_add(stream, "\n%pV", &vaf);

		va_end(args);
	}

	kunit_expect_end(test, compare_result, stream);
}

/*
 * Builds the report for an integer binary ASSERT. It starts with
 * "Asserted <left> <op> <right>, but" and then gives both operand values.
 * When @fmt is given, the formatted user message is appended on a new line.
 * The result is handed to kunit_assert_end() along with @compare_result.
 */
void kunit_assert_binary_msg(struct kunit *test,
			     long long left, const char *left_name,
			     long long right, const char *right_name,
			     bool compare_result,
			     const char *compare_name,
			     const char *file,
			     const char *line,
			     const char *fmt, ...)
{
	struct kunit_stream *stream = kunit_assert_start(test, file, line);
	struct va_format vaf;
	va_list args;

	kunit_stream_add(stream,
			 "Asserted %s %s %s, but\n",
			 left_name, compare_name, right_name);
	kunit_stream_add(stream, "\t\t%s == %lld\n", left_name, left);
	/*
	 * No trailing newline here, matching the other *_binary_msg helpers.
	 * The optional message below supplies its own leading "\n", so a
	 * trailing one would produce a blank line (or a dangling newline when
	 * there is no message).
	 */
	kunit_stream_add(stream, "\t\t%s == %lld", right_name, right);

	if (fmt) {
		va_start(args, fmt);

		vaf.fmt = fmt;
		vaf.va = &args;

		kunit_stream_add(stream, "\n%pV", &vaf);

		va_end(args);
	}

	kunit_assert_end(test, compare_result, stream);
}

void kunit_assert_ptr_binary_msg(struct kunit *test,
				 void *left, const char *left_name,
				 void *right, const char *right_name,
				 bool compare_result,
				 const char *compare_name,
				 const char *file,
				 const char *line,
				 const char *fmt, ...)
{
	struct kunit_stream *stream = kunit_assert_start(test, file, line);
	struct va_format vaf;
	va_list args;

	kunit_stream_add(stream,
			 "Asserted %s %s %s, but\n",
			 left_name, compare_name, right_name);
	kunit_stream_add(stream, "\t\t%s == %pK\n", left_name, left);
	kunit_stream_add(stream, "\t\t%s == %pK", right_name, right);

	if (fmt) {
		va_start(args, fmt);

		vaf.fmt = fmt;
		vaf.va = &args;

		kunit_stream_add(stream, "\n%pV", &vaf);

		va_end(args);
	}

	kunit_assert_end(test, compare_result, stream);
}
