kunit: test: port test abort support to x86

Fixes alignment issue with linker section based
registering of test modules.
Adds support for abort function.
Based on commit kunit: test: add support for test abort
("668e156c851bdb4f152ee86ffafe07898f64ad09").

Google-Bug-Id: 116155953
Signed-off-by: Iurii Zaikin <yzaikin@google.com>
Change-Id: Ic028ba95333c957def32587b8832337f823dbb90
diff --git a/include/asm-generic/vmlinux.lds.h b/include/asm-generic/vmlinux.lds.h
index 424a219..cbb6461 100644
--- a/include/asm-generic/vmlinux.lds.h
+++ b/include/asm-generic/vmlinux.lds.h
@@ -798,7 +798,9 @@
 		KEEP(*(.security_initcall.init))			\
 		__security_initcall_end = .;
 
+/* Alignment must be consistent with (test_module *) in include/test/test.h */
 #define KUNIT_TEST_MODULES						\
+		. = ALIGN(8);						\
 		__test_modules_start = .;				\
 		KEEP(*(.test_modules))					\
 		__test_modules_end = .;
diff --git a/include/test/test.h b/include/test/test.h
index b0007c4..1369963 100644
--- a/include/test/test.h
+++ b/include/test/test.h
@@ -13,6 +13,7 @@
 #include <linux/slab.h>
 #include <test/strerror.h>
 #include <test/test-stream.h>
+#include <test/try-catch.h>
 
 /**
  * struct test_resource - represents a *test managed resource*
@@ -172,6 +173,7 @@
 struct test {
 	void *priv;
 	/* private: internal use only. */
+	spinlock_t lock; /* Guards all mutable test state. */
 	struct list_head resources;
 	struct list_head post_conditions;
 	const char *name;
@@ -182,8 +184,18 @@
 			struct va_format *vaf);
 	void (*fail)(struct test *test, struct test_stream *stream);
 	void (*abort)(struct test *test);
+	struct test_try_catch try_catch;
 };
 
+static inline void test_set_death_test(struct test *test, bool death_test)
+{
+	unsigned long flags;
+
+	spin_lock_irqsave(&test->lock, flags);
+	test->death_test = death_test;
+	spin_unlock_irqrestore(&test->lock, flags);
+}
+
 int test_init_test(struct test *test, const char *name);
 
 int test_run_tests(struct test_module *module);
@@ -216,10 +228,18 @@
  *
  * Registers @module with the test framework. See &struct test_module for more
  * information.
+ * Hardcoding the alignment to 8 was chosen as the most likely to remain
+ * between the compiler laying out the test module pointers in the custom
+ * section and the linker script placing the custom section in the output
+ * binary. There must be no gap between the section start and the first
+ * (test_module *) entry nor between any (test_module *) entries because
+ * the test executor views the .test_modules section as an array of
+ * (test_module *) starting at __test_modules_start.
  */
 #define module_test(module) \
 		static struct test_module *__test_module_##module __used       \
-	__attribute__((__section__(".test_modules"))) = &module
+		__aligned(8) __attribute__((__section__(".test_modules"))) = \
+			&module
 
 /**
  * test_alloc_resource() - Allocates a *test managed resource*.
diff --git a/include/test/try-catch.h b/include/test/try-catch.h
new file mode 100644
index 0000000..2b2a790
--- /dev/null
+++ b/include/test/try-catch.h
@@ -0,0 +1,91 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * An API to allow a function, that may fail, to be executed, and recover in a
+ * controlled manner.
+ *
+ * Copyright (C) 2019, Google LLC.
+ * Author: Brendan Higgins <brendanhiggins@google.com>
+ */
+
+#ifndef _TEST_TRY_CATCH_H
+#define _TEST_TRY_CATCH_H
+
+#include <linux/types.h>
+
+typedef void (*test_try_catch_func_t)(void *);
+
+struct test;
+
+/*
+ * struct test_try_catch - provides a generic way to run code which might fail.
+ * @context: used to pass user data to the try and catch functions.
+ *
+ * test_try_catch provides a generic, architecture independent way to execute
+ * an arbitrary function of type test_try_catch_func_t which may bail out by
+ * calling test_try_catch_throw(). If test_try_catch_throw() is called, @try
+ * is stopped at the site of invocation and @catch is catch is called.
+ *
+ * struct test_try_catch provides a generic interface for the functionality
+ * needed to implement test->abort() which in turn is needed for implementing
+ * assertions. Assertions allow stating a precondition for a test simplifying
+ * how test cases are written and presented.
+ *
+ * Assertions are like expectations, except they abort (call
+ * test_try_catch_throw()) when the specified condition is not met. This is
+ * useful when you look at a test case as a logical statement about some piece
+ * of code, where assertions are the premises for the test case, and the
+ * conclusion is a set of predicates, rather expectations, that must all be
+ * true. If your premises are violated, it does not makes sense to continue.
+ */
+struct test_try_catch {
+	/* private: internal use only. */
+	void (*run)(struct test_try_catch *try_catch);
+	void __noreturn (*throw)(struct test_try_catch *try_catch);
+	struct test *test;
+	struct completion *try_completion;
+	int try_result;
+	test_try_catch_func_t try;
+	test_try_catch_func_t catch;
+	void *context;
+};
+
+/*
+ * Exposed to be overridden for other architectures.
+ */
+void test_try_catch_init_internal(struct test_try_catch *try_catch);
+
+static inline void test_try_catch_init(struct test_try_catch *try_catch,
+					struct test *test,
+					test_try_catch_func_t try,
+					test_try_catch_func_t catch)
+{
+	try_catch->test = test;
+	test_try_catch_init_internal(try_catch);
+	try_catch->try = try;
+	try_catch->catch = catch;
+}
+
+static inline void test_try_catch_run(struct test_try_catch *try_catch,
+				       void *context)
+{
+	try_catch->context = context;
+	try_catch->run(try_catch);
+}
+
+static inline void __noreturn test_try_catch_throw(
+		struct test_try_catch *try_catch)
+{
+	try_catch->throw(try_catch);
+}
+
+static inline int test_try_catch_get_result(struct test_try_catch *try_catch)
+{
+	return try_catch->try_result;
+}
+
+/*
+ * Exposed for testing only.
+ */
+void test_generic_try_catch_init(struct test_try_catch *try_catch);
+
+#endif /* _TEST_TRY_CATCH_H */
diff --git a/test/Makefile b/test/Makefile
index f840886..28cdd28c 100644
--- a/test/Makefile
+++ b/test/Makefile
@@ -1,5 +1,5 @@
 obj-$(CONFIG_TEST)		+= test.o mock.o common-mocks.o strerror.o \
-  string-stream.o test-stream.o test-executor.o
+  string-stream.o test-stream.o test-executor.o try-catch.o
 obj-$(CONFIG_TEST_TEST)		+= \
   test-mock.o mock-macro-test.o mock-test.o strerror-test.o \
   string-stream-test.o test-stream-test.o test-test.o
diff --git a/test/test-executor.c b/test/test-executor.c
index 92a5289..d117e96 100644
--- a/test/test-executor.c
+++ b/test/test-executor.c
@@ -2,15 +2,19 @@
 #include <linux/printk.h>
 #include <test/test.h>
 
-extern struct test_module *__test_modules_start[];
-extern struct test_module *__test_modules_end[];
+extern char __test_modules_start;
+extern char __test_modules_end;
 
 static bool test_run_all_tests(void)
 {
-	struct test_module** module;
+	struct test_module **module;
+	struct test_module ** const test_modules_start =
+			(struct test_module **) &__test_modules_start;
+	struct test_module ** const test_modules_end =
+			(struct test_module **) &__test_modules_end;
 	bool has_test_failed = false;
 
-	for (module = __test_modules_start; module < __test_modules_end; ++module) {
+	for (module = test_modules_start; module < test_modules_end; ++module) {
 		if (test_run_tests(*module))
 			has_test_failed = true;
 	}
@@ -26,4 +30,3 @@
 	else
 		return -EFAULT;
 }
-
diff --git a/test/test.c b/test/test.c
index fbfd8e9..1e99413 100644
--- a/test/test.c
+++ b/test/test.c
@@ -8,8 +8,8 @@
 
 #include <linux/sched.h>
 #include <linux/sched/debug.h>
-#include <os.h>
 #include <test/test.h>
+#include <test/try-catch.h>
 
 struct test_global_context {
 	struct list_head initcalls;
@@ -24,6 +24,39 @@
 	list_add_tail(&initcall->node, &test_global_context.initcalls);
 }
 
+static bool test_get_success(struct test *test)
+{
+	unsigned long flags;
+	bool success;
+
+	spin_lock_irqsave(&test->lock, flags);
+	success = test->success;
+	spin_unlock_irqrestore(&test->lock, flags);
+
+	return success;
+}
+
+static void test_set_success(struct test *test, bool success)
+{
+	unsigned long flags;
+
+	spin_lock_irqsave(&test->lock, flags);
+	test->success = success;
+	spin_unlock_irqrestore(&test->lock, flags);
+}
+
+static bool test_get_death_test(struct test *test)
+{
+	unsigned long flags;
+	bool death_test;
+
+	spin_lock_irqsave(&test->lock, flags);
+	death_test = test->death_test;
+	spin_unlock_irqrestore(&test->lock, flags);
+
+	return death_test;
+}
+
 static int test_vprintk_emit(const struct test *test,
 			     int level,
 			     const char *fmt,
@@ -52,33 +85,28 @@
 {
 	test_printk_emit(test,
 			 level[1] - '0',
-			 "kunit %s: %pV", test->name, vaf);
+			 "test %s: %pV", test->name, vaf);
 }
 
 static void test_fail(struct test *test, struct test_stream *stream)
 {
-	test->success = false;
+	test_set_success(test, false);
 	stream->set_level(stream, KERN_ERR);
 	stream->commit(stream);
 }
 
 static void __noreturn test_abort(struct test *test)
 {
-	test->death_test = true;
-	if (current->thread.fault_catcher && current->thread.is_running_test)
-		UML_LONGJMP(current->thread.fault_catcher, 1);
+	test_set_death_test(test, true);
+	test_try_catch_throw(&test->try_catch);
 
 	/*
-	 * Attempted to abort from a not properly initialized test context.
+	 * Throw could not abort from test.
+	 *
+	 * XXX: we should never reach this line! As test_try_catch_throw is
+	 * marked __noreturn.
 	 */
-	test_err(test,
-		 "Attempted to abort from a not properly initialized test context!");
-	if (!current->thread.fault_catcher)
-		test_err(test, "No fault_catcher present!");
-	if (!current->thread.is_running_test)
-		test_err(test, "is_running_test not set!");
-	show_stack(NULL, NULL);
-	BUG();
+	WARN_ONCE(true, "Throw could not abort from test!\n");
 }
 
 int test_init_test(struct test *test, const char *name)
@@ -93,6 +121,17 @@
 	return 0;
 }
 
+static void test_case_internal_cleanup(struct test *test)
+{
+	struct test_initcall *initcall;
+
+	list_for_each_entry(initcall, &test_global_context.initcalls, node) {
+		initcall->exit(initcall);
+	}
+
+	test_cleanup(test);
+}
+
 /*
  * Initializes and runs test case. Does not clean up or do post validations.
  */
@@ -107,7 +146,7 @@
 		ret = initcall->init(initcall, test);
 		if (ret) {
 			test_err(test, "failed to initialize: %d", ret);
-			test->success = false;
+			test_set_success(test, false);
 			return;
 		}
 	}
@@ -116,7 +155,7 @@
 		ret = module->init(test);
 		if (ret) {
 			test_err(test, "failed to initialize: %d", ret);
-			test->success = false;
+			test_set_success(test, false);
 			return;
 		}
 	}
@@ -124,17 +163,30 @@
 	test_case->run_case(test);
 }
 
-static void test_case_internal_cleanup(struct test *test)
+/*
+ * Handles an unexpected crash in a test case.
+ */
+static void test_handle_test_crash(struct test *test,
+				   struct test_module *module,
+				   struct test_case *test_case)
 {
-	struct test_initcall *initcall;
+	test_err(test, "%s crashed", test_case->name);
+	/*
+	 * TODO(brendanhiggins@google.com): This prints the stack trace up
+	 * through this frame, not up to the frame that caused the crash.
+	 */
+	show_stack(NULL, NULL);
 
-	list_for_each_entry(initcall, &test_global_context.initcalls, node) {
-		initcall->exit(initcall);
-	}
-
-	test_cleanup(test);
+	test_case_internal_cleanup(test);
 }
 
+struct test_try_catch_context {
+	struct test *test;
+	struct test_module *module;
+	struct test_case *test_case;
+};
+
+
 /*
  * Performs post validations and cleanup after a test case was run.
  * XXX: Should ONLY BE CALLED AFTER test_run_case_internal!
@@ -159,69 +211,50 @@
 	test_case_internal_cleanup(test);
 }
 
-/*
- * Handles an unexpected crash in a test case.
- */
-static void test_handle_test_crash(struct test *test,
-				   struct test_module *module,
-				   struct test_case *test_case)
+static void test_try_run_case(void *data)
 {
-	/*
-	 * TODO(brendanhiggins@google.com): Right now we don't have a way to
-	 * store a copy of the stack, or a copy of information from the stack,
-	 * so we need to print it in the "trap" handler; otherwise, the stack
-	 * will be destroyed when it returns to us by popping off the
-	 * appropriate stack frames (see longjmp).
-	 *
-	 * Ideally we would print the stack trace here, but we do not have the
-	 * ability to do so with meaningful information at this time.
-	 */
-	test_err(test, "%s crashed", test_case->name);
+	struct test_try_catch_context *ctx = data;
+	struct test *test = ctx->test;
+	struct test_module *module = ctx->module;
+	struct test_case *test_case = ctx->test_case;
 
-	test_case_internal_cleanup(test);
+	/*
+	 * test_run_case_internal may encounter a fatal error; if it does,
+	 * abort will be called, this thread will exit, and finally the parent
+	 * thread will resume control and handle any necessary clean up.
+	 */
+	test_run_case_internal(test, module, test_case);
+	/* This line may never be reached. */
+	test_run_case_cleanup(test, module, test_case);
 }
 
-/*
- * Performs all logic to run a test case. It also catches most errors that
- * occurs in a test case and reports them as failures.
- *
- * XXX: THIS DOES NOT FOLLOW NORMAL CONTROL FLOW. READ CAREFULLY!!!
- */
-static bool test_run_case_catch_errors(struct test *test,
-				       struct test_module *module,
-				       struct test_case *test_case)
+static void test_catch_run_case(void *data)
 {
-	jmp_buf fault_catcher;
-	int faulted;
+	struct test_try_catch_context *ctx = data;
+	struct test *test = ctx->test;
+	struct test_module *module = ctx->module;
+	struct test_case *test_case = ctx->test_case;
+	int try_exit_code = test_try_catch_get_result(&test->try_catch);
 
-	test->success = true;
-	test->death_test = false;
-
-	/*
-	 * Tell the trap subsystem that we want to catch any segfaults that
-	 * occur.
-	 */
-	current->thread.is_running_test = true;
-	current->thread.fault_catcher = &fault_catcher;
-
-	/*
-	 * ENTER HANDLER: If a failure occurs, we enter here.
-	 */
-	faulted = UML_SETJMP(&fault_catcher);
-	if (faulted == 0) {
+	if (try_exit_code) {
+		test_set_success(test, false);
 		/*
-		 * NORMAL CASE: we have not run test_run_case_internal yet.
-		 *
-		 * test_run_case_internal may encounter a fatal error; if it
-		 * does, we will jump to ENTER_HANDLER above instead of
-		 * continuing normal control flow.
+		 * Test case could not finish, we have no idea what state it is
+		 * in, so don't do clean up.
 		 */
-		test_run_case_internal(test, module, test_case);
+		if (try_exit_code == -ETIMEDOUT)
+			test_err(test, "test case timed out\n");
 		/*
-		 * This line may never be reached.
+		 * Unknown internal error occurred preventing test case from
+		 * running, so there is nothing to clean up.
 		 */
-		test_run_case_cleanup(test, module, test_case);
-	} else if (test->death_test) {
+		else
+			test_err(test, "internal error occurred preventing test case from running: %d\n",
+				  try_exit_code);
+		return;
+	}
+
+	if (test_get_death_test(test)) {
 		/*
 		 * EXPECTED DEATH: test_run_case_internal encountered
 		 * anticipated fatal error. Everything should be in a safe
@@ -235,21 +268,34 @@
 		 * the test case is in.
 		 */
 		test_handle_test_crash(test, module, test_case);
-		test->success = false;
+		test_set_success(test, false);
 	}
-	/*
-	 * EXIT HANDLER: test case has been run and all possible errors have
-	 * been handled.
-	 */
+}
 
-	/*
-	 * Tell the trap subsystem that we no longer want to catch any
-	 * segfaults.
-	 */
-	current->thread.fault_catcher = NULL;
-	current->thread.is_running_test = false;
+/*
+ * Performs all logic to run a test case. It also catches most errors that
+ * occurs in a test case and reports them as failures.
+ */
+static bool test_run_case_catch_errors(struct test *test,
+				       struct test_module *module,
+				       struct test_case *test_case)
+{
+	struct test_try_catch *try_catch = &test->try_catch;
+	struct test_try_catch_context context;
 
-	return test->success;
+	test_set_success(test, true);
+	test_set_death_test(test, false);
+
+	test_try_catch_init(try_catch,
+			    test,
+			    test_try_run_case,
+			    test_catch_run_case);
+	context.test = test;
+	context.module = module;
+	context.test_case = test_case;
+	test_try_catch_run(try_catch, &context);
+
+	return test_get_success(test);
 }
 
 int test_run_tests(struct test_module *module)
diff --git a/test/try-catch.c b/test/try-catch.c
new file mode 100644
index 0000000..ec8a3bb
--- /dev/null
+++ b/test/try-catch.c
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * An API to allow a function, that may fail, to be executed, and recover in a
+ * controlled manner.
+ *
+ * Copyright (C) 2019, Google LLC.
+ * Author: Brendan Higgins <brendanhiggins@google.com>
+ */
+
+#include <test/try-catch.h>
+#include <test/test.h>
+#include <linux/completion.h>
+#include <linux/kthread.h>
+
+static void __noreturn test_generic_throw(struct test_try_catch *try_catch)
+{
+	try_catch->try_result = -EFAULT;
+	complete_and_exit(try_catch->try_completion, -EFAULT);
+}
+
+static int test_generic_run_threadfn_adapter(void *data)
+{
+	struct test_try_catch *try_catch = data;
+
+	try_catch->try(try_catch->context);
+
+	complete_and_exit(try_catch->try_completion, 0);
+}
+
+static void test_generic_run_try_catch(struct test_try_catch *try_catch)
+{
+	DECLARE_COMPLETION_ONSTACK(try_completion);
+	struct test *test = try_catch->test;
+	struct task_struct *task_struct;
+	int exit_code, status;
+
+	try_catch->try_completion = &try_completion;
+	try_catch->try_result = 0;
+	task_struct = kthread_run(test_generic_run_threadfn_adapter,
+				  try_catch,
+				  "test_try_catch_thread");
+	if (IS_ERR(task_struct)) {
+		try_catch->catch(try_catch->context);
+		return;
+	}
+
+	/*
+	 * TODO(brendanhiggins@google.com): We should probably have some type of
+	 * variable timeout here. The only question is what that timeout value
+	 * should be.
+	 *
+	 * The intention has always been, at some point, to be able to label
+	 * tests with some type of size bucket (unit/small, integration/medium,
+	 * large/system/end-to-end, etc), where each size bucket would get a
+	 * default timeout value kind of like what Bazel does:
+	 * https://docs.bazel.build/versions/master/be/common-definitions.html#test.size
+	 * There is still some debate to be had on exactly how we do this. (For
+	 * one, we probably want to have some sort of test runner level
+	 * timeout.)
+	 *
+	 * For more background on this topic, see:
+	 * https://mike-bland.com/2011/11/01/small-medium-large.html
+	 */
+	status = wait_for_completion_timeout(&try_completion,
+					     300 * MSEC_PER_SEC); /* 5 min */
+	if (status < 0) {
+		test_err(test, "try timed out\n");
+		try_catch->try_result = -ETIMEDOUT;
+	}
+
+	exit_code = try_catch->try_result;
+
+	if (!exit_code)
+		return;
+
+	if (exit_code == -EFAULT)
+		try_catch->try_result = 0;
+	else if (exit_code == -EINTR)
+		test_err(test, "wake_up_process() was never called\n");
+	else if (exit_code)
+		test_err(test, "Unknown error: %d\n", exit_code);
+
+	try_catch->catch(try_catch->context);
+}
+
+void test_generic_try_catch_init(struct test_try_catch *try_catch)
+{
+	try_catch->run = test_generic_run_try_catch;
+	try_catch->throw = test_generic_throw;
+}
+
+void __weak test_try_catch_init_internal(struct test_try_catch *try_catch)
+{
+	test_generic_try_catch_init(try_catch);
+}