kunit_tool: refactored kunit_tool for adding subcommands

Cleaned up kunit_tool a bit to make adding subcommands easier. Also
added some tests.

Google-Bug-Id: 117125357
Signed-off-by: Brendan Higgins <brendanhiggins@google.com>
Change-Id: Ic032bb64c33f6f9a70b51a191caa7019eecdff02
diff --git a/tools/testing/kunit/kunit.py b/tools/testing/kunit/kunit.py
index 5421af3..938b0da 100755
--- a/tools/testing/kunit/kunit.py
+++ b/tools/testing/kunit/kunit.py
@@ -12,20 +12,7 @@
 import kunit_kernel
 import kunit_parser
 
-parser = argparse.ArgumentParser(description='Runs KUnit tests.')
-
-parser.add_argument('--raw_output', help='don\'t format output from kernel',
-		    action='store_true')
-
-parser.add_argument('--timeout', help='maximum number of seconds to allow for '
-		    'all tests to run. This does not include time taken to '
-		    'build the tests.', type=int, default=300,
-		    metavar='timeout')
-
-cli_args = parser.parse_args()
-
-def main(linux):
-
+def run_tests(cli_args, linux):
 	config_start = time.time()
 	success = linux.build_reconfig()
 	config_end = time.time()
@@ -49,16 +36,43 @@
 	else:
 		for line in kunit_parser.parse_run_tests(
 			kunit_parser.isolate_kunit_output(
-				linux.run_kernel(timeout=cli_args.timeout))):
+				linux.run_kernel(
+					timeout=cli_args.timeout))):
 			print(line)
 
 	test_end = time.time()
 
 	print(kunit_parser.timestamp((
 		'Elapsed time: %.3fs total, %.3fs configuring, %.3fs ' +
-		'building, %.3fs running.\n') % (test_end - config_start,
-		config_end - config_start, build_end - build_start,
-		test_end - test_start)))
+		'building, %.3fs running.\n') % (
+				test_end - config_start,
+				config_end - config_start,
+				build_end - build_start,
+				test_end - test_start)))
+
+def main(argv, linux=kunit_kernel.LinuxSourceTree()):
+	parser = argparse.ArgumentParser(
+			description='Helps writing and running KUnit tests.')
+	subparser = parser.add_subparsers(dest='subcommand')
+
+	run_parser = subparser.add_parser('run', help='Runs KUnit tests.')
+	run_parser.add_argument('--raw_output', help='don\'t format output from kernel',
+				action='store_true')
+
+	run_parser.add_argument('--timeout',
+				help='maximum number of seconds to allow for all tests '
+				'to run. This does not include time taken to build the '
+				'tests.',
+				type=int,
+				default=300,
+				metavar='timeout')
+
+	cli_args = parser.parse_args(argv)
+
+	if cli_args.subcommand == 'run':
+		run_tests(cli_args, linux)
+	else:
+		parser.print_help()
 
 if __name__ == '__main__':
-	main(kunit_kernel.LinuxSourceTree())
\ No newline at end of file
+	main(sys.argv[1:])
diff --git a/tools/testing/kunit/kunit_test.py b/tools/testing/kunit/kunit_test.py
index 65729a3..816ba07 100755
--- a/tools/testing/kunit/kunit_test.py
+++ b/tools/testing/kunit/kunit_test.py
@@ -1,6 +1,7 @@
 #!/usr/bin/python3
 
 import unittest
+from unittest import mock
 
 import tempfile, shutil # Handling test_tmpdir
 
@@ -160,5 +161,51 @@
 			result)
 		file.close()
 
+class StrContains(str):
+	def __eq__(self, other):
+		return self in other
+
+class KUnitMainTest(unittest.TestCase):
+	def setUp(self):
+		self.print_patch = mock.patch('builtins.print')
+		self.print_mock = self.print_patch.start()
+		self.linux_source_mock = mock.Mock()
+		self.linux_source_mock.build_reconfig = mock.Mock()
+		self.linux_source_mock.run_kernel = mock.Mock(return_value=[
+				'console 0 enabled',
+				'List of all partitions:'])
+
+	def tearDown(self):
+		self.print_patch.stop()
+
+	def test_run_passes_args_pass(self):
+		kunit.main(['run'], self.linux_source_mock)
+		assert self.linux_source_mock.build_reconfig.call_count == 1
+		assert self.linux_source_mock.run_kernel.call_count == 1
+		self.print_mock.assert_any_call(StrContains('Testing complete.'))
+
+	def test_run_passes_args_fail(self):
+		self.linux_source_mock.run_kernel = mock.Mock(return_value=[])
+		kunit.main(['run'], self.linux_source_mock)
+		assert self.linux_source_mock.build_reconfig.call_count == 1
+		assert self.linux_source_mock.run_kernel.call_count == 1
+		self.print_mock.assert_any_call(StrContains('Before the crash:'))
+
+	def test_run_raw_output(self):
+		self.linux_source_mock.run_kernel = mock.Mock(return_value=[])
+		kunit.main(['run', '--raw_output'], self.linux_source_mock)
+		assert self.linux_source_mock.build_reconfig.call_count == 1
+		assert self.linux_source_mock.run_kernel.call_count == 1
+		for kall in self.print_mock.call_args_list:
+			assert kall != mock.call(StrContains('Testing complete.'))
+			assert kall != mock.call(StrContains('Before the crash:'))
+
+	def test_run_timeout(self):
+		timeout = 3453
+		kunit.main(['run', '--timeout', str(timeout)], self.linux_source_mock)
+		assert self.linux_source_mock.build_reconfig.call_count == 1
+		self.linux_source_mock.run_kernel.assert_called_once_with(timeout=timeout)
+		self.print_mock.assert_any_call(StrContains('Testing complete.'))
+
 if __name__ == '__main__':
 	unittest.main()