lwtunnel: change prototype of lwtunnel_state_get()

It saves some lines and simplify a bit the code when the state is returning
by this function. It's also useful to handle a NULL entry.

To avoid too long lines, I've also renamed lwtunnel_state_get() and
lwtunnel_state_put() to lwtstate_get() and lwtstate_put().

CC: Thomas Graf <tgraf@suug.ch>
CC: Roopa Prabhu <roopa@cumulusnetworks.com>
Signed-off-by: Nicolas Dichtel <nicolas.dichtel@6wind.com>
Acked-by: Thomas Graf <tgraf@suug.ch>
Acked-by: Roopa Prabhu <roopa@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/lwtunnel.h b/include/net/lwtunnel.h
index 918e03c..b020390 100644
--- a/include/net/lwtunnel.h
+++ b/include/net/lwtunnel.h
@@ -35,12 +35,16 @@
 		lwtun_encaps[LWTUNNEL_ENCAP_MAX+1];
 
 #ifdef CONFIG_LWTUNNEL
-static inline void lwtunnel_state_get(struct lwtunnel_state *lws)
+static inline struct lwtunnel_state *
+lwtstate_get(struct lwtunnel_state *lws)
 {
-	atomic_inc(&lws->refcnt);
+	if (lws)
+		atomic_inc(&lws->refcnt);
+
+	return lws;
 }
 
-static inline void lwtunnel_state_put(struct lwtunnel_state *lws)
+static inline void lwtstate_put(struct lwtunnel_state *lws)
 {
 	if (!lws)
 		return;
@@ -74,11 +78,13 @@
 
 #else
 
-static inline void lwtunnel_state_get(struct lwtunnel_state *lws)
+static inline struct lwtunnel_state *
+lwtstate_get(struct lwtunnel_state *lws)
 {
+	return lws;
 }
 
-static inline void lwtunnel_state_put(struct lwtunnel_state *lws)
+static inline void lwtstate_put(struct lwtunnel_state *lws)
 {
 }
 
diff --git a/net/ipv4/fib_semantics.c b/net/ipv4/fib_semantics.c
index d4c6732..65e0039 100644
--- a/net/ipv4/fib_semantics.c
+++ b/net/ipv4/fib_semantics.c
@@ -209,7 +209,7 @@
 	change_nexthops(fi) {
 		if (nexthop_nh->nh_dev)
 			dev_put(nexthop_nh->nh_dev);
-		lwtunnel_state_put(nexthop_nh->nh_lwtstate);
+		lwtstate_put(nexthop_nh->nh_lwtstate);
 		free_nh_exceptions(nexthop_nh);
 		rt_fibinfo_free_cpus(nexthop_nh->nh_pcpu_rth_output);
 		rt_fibinfo_free(&nexthop_nh->nh_rth_input);
@@ -514,8 +514,8 @@
 							   nla, &lwtstate);
 				if (ret)
 					goto errout;
-				lwtunnel_state_get(lwtstate);
-				nexthop_nh->nh_lwtstate = lwtstate;
+				nexthop_nh->nh_lwtstate =
+					lwtstate_get(lwtstate);
 			}
 		}
 
@@ -971,8 +971,7 @@
 			if (err)
 				goto failure;
 
-			lwtunnel_state_get(lwtstate);
-			nh->nh_lwtstate = lwtstate;
+			nh->nh_lwtstate = lwtstate_get(lwtstate);
 		}
 		nh->nh_oif = cfg->fc_oif;
 		nh->nh_gw = cfg->fc_gw;
diff --git a/net/ipv4/route.c b/net/ipv4/route.c
index 519ec23..1109639 100644
--- a/net/ipv4/route.c
+++ b/net/ipv4/route.c
@@ -1358,7 +1358,7 @@
 		list_del(&rt->rt_uncached);
 		spin_unlock_bh(&ul->lock);
 	}
-	lwtunnel_state_put(rt->rt_lwtstate);
+	lwtstate_put(rt->rt_lwtstate);
 }
 
 void rt_flush_dev(struct net_device *dev)
@@ -1407,12 +1407,7 @@
 #ifdef CONFIG_IP_ROUTE_CLASSID
 		rt->dst.tclassid = nh->nh_tclassid;
 #endif
-		if (nh->nh_lwtstate) {
-			lwtunnel_state_get(nh->nh_lwtstate);
-			rt->rt_lwtstate = nh->nh_lwtstate;
-		} else {
-			rt->rt_lwtstate = NULL;
-		}
+		rt->rt_lwtstate = lwtstate_get(nh->nh_lwtstate);
 		if (unlikely(fnhe))
 			cached = rt_bind_exception(rt, fnhe, daddr);
 		else if (!(rt->dst.flags & DST_NOCACHE))
diff --git a/net/ipv6/ip6_fib.c b/net/ipv6/ip6_fib.c
index d715f2e..5693b5e 100644
--- a/net/ipv6/ip6_fib.c
+++ b/net/ipv6/ip6_fib.c
@@ -178,7 +178,7 @@
 static void rt6_release(struct rt6_info *rt)
 {
 	if (atomic_dec_and_test(&rt->rt6i_ref)) {
-		lwtunnel_state_put(rt->rt6i_lwtstate);
+		lwtstate_put(rt->rt6i_lwtstate);
 		rt6_free_pcpu(rt);
 		dst_free(&rt->dst);
 	}
diff --git a/net/ipv6/route.c b/net/ipv6/route.c
index fbe27fb..c9b2b9f 100644
--- a/net/ipv6/route.c
+++ b/net/ipv6/route.c
@@ -1778,8 +1778,7 @@
 					   cfg->fc_encap, &lwtstate);
 		if (err)
 			goto out;
-		lwtunnel_state_get(lwtstate);
-		rt->rt6i_lwtstate = lwtstate;
+		rt->rt6i_lwtstate = lwtstate_get(lwtstate);
 		if (lwtunnel_output_redirect(rt->rt6i_lwtstate))
 			rt->dst.output = lwtunnel_output6;
 	}
@@ -2161,10 +2160,7 @@
 #endif
 	rt->rt6i_prefsrc = ort->rt6i_prefsrc;
 	rt->rt6i_table = ort->rt6i_table;
-	if (ort->rt6i_lwtstate) {
-		lwtunnel_state_get(ort->rt6i_lwtstate);
-		rt->rt6i_lwtstate = ort->rt6i_lwtstate;
-	}
+	rt->rt6i_lwtstate = lwtstate_get(ort->rt6i_lwtstate);
 }
 
 #ifdef CONFIG_IPV6_ROUTE_INFO