kunit: Add compound matchers.

Implements and, or, not matchers which perform
the respective logical operations on the inner matchers.

Signed-off-by: Iurii Zaikin <yzaikin@google.com>
Change-Id: I60ba93366216855eb616db3fe10c0534ab042145
diff --git a/include/test/mock.h b/include/test/mock.h
index 3008f83..d16ced8 100644
--- a/include/test/mock.h
+++ b/include/test/mock.h
@@ -1265,6 +1265,15 @@
 					 struct mock_param_matcher *fmt_matcher,
 					 struct mock_param_matcher *va_matcher);
 
+/* Compound matchers */
+struct mock_param_matcher *and(struct test *test,
+			       struct mock_param_matcher *left_matcher,
+			       struct mock_param_matcher *right_matcher);
+struct mock_param_matcher *or(struct test *test,
+			      struct mock_param_matcher *left_matcher,
+			      struct mock_param_matcher *right_matcher);
+struct mock_param_matcher *not(struct test *test,
+			       struct mock_param_matcher *inner_matcher);
 struct mock_action *u8_return(struct test *test, u8 ret);
 struct mock_action *u16_return(struct test *test, u16 ret);
 struct mock_action *u32_return(struct test *test, u32 ret);
diff --git a/test/common-mocks.c b/test/common-mocks.c
index ed061b1..f37ada6 100644
--- a/test/common-mocks.c
+++ b/test/common-mocks.c
@@ -629,3 +629,102 @@
 
 	return &formatter->formatter;
 }
+
+/* Composite matchers: match_and, match_or. */
+
+struct mock_composite_binary_matcher {
+	struct mock_param_matcher *left_matcher;
+	struct mock_param_matcher *right_matcher;
+	struct mock_param_matcher matcher;
+};
+
+#define DEFINE_COMPOSITE_BINARY_MATCH_FUNC(combine_op, op_name)                \
+bool match_##op_name(struct mock_param_matcher *pmatcher,                      \
+		     struct test_stream *stream,                              \
+		     const void *pactual)                                      \
+{                                                                              \
+	bool result;                                                           \
+	struct mock_composite_binary_matcher *compound_matcher =               \
+			container_of(pmatcher,                                 \
+				     struct mock_composite_binary_matcher,     \
+				     matcher);                                 \
+									       \
+	stream->add(stream, "((");                                        \
+	result = compound_matcher->left_matcher->match(                        \
+	    compound_matcher->left_matcher, stream, pactual);                  \
+	stream->add(stream, ") " #op_name " ((");                         \
+	result combine_op compound_matcher->right_matcher->match(              \
+	 compound_matcher->right_matcher, stream, pactual);                    \
+	stream->add(stream, "))");                                        \
+									       \
+	return result;                                                         \
+}                                                                              \
+
+#define DEFINE_COMPOSITE_BINARY_MATCHER_FACTORY(combine_op, op_name)           \
+struct mock_param_matcher *op_name(struct test *test,                 \
+				    struct mock_param_matcher *left_matcher,  \
+				    struct mock_param_matcher *right_matcher) \
+{                                                                              \
+	struct mock_composite_binary_matcher *matcher;                         \
+									       \
+	matcher = test_kmalloc(test,                                          \
+				sizeof(*matcher),                              \
+				GFP_KERNEL);                                   \
+	if (!matcher)                                                          \
+		return NULL;                                                   \
+									       \
+	matcher->matcher.match = match_##op_name;                              \
+	matcher->left_matcher = left_matcher;                                  \
+	matcher->right_matcher = right_matcher;                                \
+	return &matcher->matcher;                                              \
+}                                                                              \
+
+#define COMPOSITE_BINARY_MATCHER(combine_op, op_name)                          \
+	DEFINE_COMPOSITE_BINARY_MATCH_FUNC(combine_op, op_name)                \
+	DEFINE_COMPOSITE_BINARY_MATCHER_FACTORY(combine_op, op_name)           \
+
+/* Conjunction of the inner matchers. */
+COMPOSITE_BINARY_MATCHER(&=, and);
+/* Disjunction of the inner matchers. */
+COMPOSITE_BINARY_MATCHER(|=, or);
+
+struct mock_composite_unary_matcher {
+	struct mock_param_matcher *inner_matcher;
+	struct mock_param_matcher matcher;
+};
+
+bool match_not(struct mock_param_matcher *pmatcher,
+	       struct test_stream *stream,
+	       const void *pactual)
+{
+	bool result;
+	struct mock_composite_unary_matcher *compound_matcher =
+			container_of(pmatcher,
+				     struct mock_composite_unary_matcher,
+				     matcher);
+
+	stream->add(stream, "not (");
+	result = !compound_matcher->inner_matcher->match(
+	    compound_matcher->inner_matcher,
+	    stream,
+	    pactual);
+	stream->add(stream, ")");
+	return result;
+}
+
+/* Negates the result of the inner matcher */
+struct mock_param_matcher *not(struct test *test,
+			       struct mock_param_matcher *inner_matcher)
+{
+	struct mock_composite_unary_matcher *matcher;
+
+	matcher = test_kmalloc(test,
+			       sizeof(*matcher),
+			       GFP_KERNEL);
+	if (!matcher)
+		return NULL;
+
+	matcher->matcher.match = match_not;
+	matcher->inner_matcher = inner_matcher;
+	return &matcher->matcher;
+}
diff --git a/test/mock-test.c b/test/mock-test.c
index 48ac0ac..2b80d37 100644
--- a/test/mock-test.c
+++ b/test/mock-test.c
@@ -800,6 +800,276 @@
 	return 0;
 }
 
+static void mock_test_and_matcher_accept(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		and(test, int_gt(test, 4), int_lt(test, 6))
+	};
+	struct mock_expectation *expectation;
+
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	ASSERT_NOT_ERR_OR_NULL(test, ret);
+	EXPECT_EQ(test, 1, expectation->times_called);
+}
+
+
+static void mock_test_and_matcher_reject_left(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		and(test, int_gt(test, 5), int_lt(test, 6))
+	};
+	struct mock_expectation *expectation;
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	EXPECT_FALSE(test, ret);
+	EXPECT_EQ(test, 0, expectation->times_called);
+}
+
+static void mock_test_and_matcher_reject_right(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		and(test, int_gt(test, 4), int_lt(test, 5))
+	};
+	struct mock_expectation *expectation;
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	EXPECT_FALSE(test, ret);
+	EXPECT_EQ(test, 0, expectation->times_called);
+}
+
+static void mock_test_or_matcher_reject(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		or(test, int_lt(test, 4), int_gt(test, 6))
+	};
+	struct mock_expectation *expectation;
+
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	EXPECT_FALSE(test, ret);
+	EXPECT_EQ(test, 0, expectation->times_called);
+}
+
+
+static void mock_test_or_matcher_accept_left(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		or(test, int_gt(test, 4), int_gt(test, 6))
+	};
+	struct mock_expectation *expectation;
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	ASSERT_NOT_ERR_OR_NULL(test, ret);
+	EXPECT_EQ(test, 1, expectation->times_called);
+}
+
+static void mock_test_or_matcher_accept_right(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		or(test, int_lt(test, 4), int_lt(test, 6))
+	};
+	struct mock_expectation *expectation;
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	ASSERT_NOT_ERR_OR_NULL(test, ret);
+	EXPECT_EQ(test, 1, expectation->times_called);
+}
+
+static void mock_test_not_matcher_reject(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		not(test, int_eq(test, 5))
+	};
+	struct mock_expectation *expectation;
+
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	EXPECT_FALSE(test, ret);
+	EXPECT_EQ(test, 0, expectation->times_called);
+}
+
+
+static void mock_test_not_matcher_accept(struct test *test)
+{
+	struct mock_test_context *ctx = test->priv;
+	struct MOCK(test) *mock_test = ctx->mock_test;
+	struct test *trgt = mock_get_trgt(mock_test);
+	struct mock *mock = ctx->mock;
+	const int param0 = 5;
+	static const char * const param_types[] = {"int"};
+	const void *params[] = {&param0};
+	struct mock_param_matcher *matchers[] = {
+		not(test, int_eq(test, 100500))
+	};
+	struct mock_expectation *expectation;
+	const void *ret;
+
+	expectation = mock_add_matcher(mock,
+				       "",
+				       NULL,
+				       matchers,
+				       ARRAY_SIZE(matchers));
+	expectation->action = int_return(trgt, 0);
+	EXPECT_EQ(test, 0, expectation->times_called);
+
+	ret = mock->do_expect(mock,
+			      "",
+			      NULL,
+			      param_types,
+			      params,
+			      ARRAY_SIZE(params));
+	ASSERT_NOT_ERR_OR_NULL(test, ret);
+	EXPECT_EQ(test, 1, expectation->times_called);
+}
+
 static struct test_case mock_test_cases[] = {
 	TEST_CASE(mock_test_do_expect_basic),
 	TEST_CASE(mock_test_ptr_eq),
@@ -819,13 +1089,21 @@
 	TEST_CASE(mock_test_in_sequence_abc_success),
 	TEST_CASE(mock_test_in_sequence_bac_success),
 	TEST_CASE(mock_test_in_sequence_no_a_fail),
-        TEST_CASE(mock_test_in_sequence_retire_on_saturation),
-        TEST_CASE(mock_test_atleast),
-        TEST_CASE(mock_test_atleast_fail),
-        TEST_CASE(mock_test_atmost),
-        TEST_CASE(mock_test_atmost_fail),
-        TEST_CASE(mock_test_between),
-        TEST_CASE(mock_test_between_fail),
+	TEST_CASE(mock_test_in_sequence_retire_on_saturation),
+	TEST_CASE(mock_test_atleast),
+	TEST_CASE(mock_test_atleast_fail),
+	TEST_CASE(mock_test_atmost),
+	TEST_CASE(mock_test_atmost_fail),
+	TEST_CASE(mock_test_between),
+	TEST_CASE(mock_test_between_fail),
+	TEST_CASE(mock_test_and_matcher_accept),
+	TEST_CASE(mock_test_and_matcher_reject_left),
+	TEST_CASE(mock_test_and_matcher_reject_right),
+	TEST_CASE(mock_test_or_matcher_reject),
+	TEST_CASE(mock_test_or_matcher_accept_left),
+	TEST_CASE(mock_test_or_matcher_accept_right),
+	TEST_CASE(mock_test_not_matcher_reject),
+	TEST_CASE(mock_test_not_matcher_accept),
         {},
 };