diff --git a/main.go b/main.go index 46f2637..2ea3f74 100644 --- a/main.go +++ b/main.go @@ -15,7 +15,8 @@ import ( "golang.org/x/sys/unix" - "github.com/vishvananda/netlink" + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" ) var ( @@ -65,25 +66,31 @@ func newRouteStats() *routeStats { } } -func (rs *routeStats) update(ru netlink.RouteUpdate) { - rs.Lock() - defer rs.Unlock() - - switch ru.Type { - case unix.RTM_NEWROUTE: - rs.add(ru) - case unix.RTM_DELROUTE: - rs.del(ru) +func DstIPNet(rm rtnetlink.RouteMessage) *net.IPNet { + var zeros int + switch rm.Family { + case unix.AF_INET: + zeros = 32 + case unix.AF_INET6: + zeros = 128 default: - fmt.Fprintf(os.Stderr, "Unknown route type %d\n", ru.Type) + fmt.Fprintf(os.Stderr, "unexpected family %q", rm.Family) + } + return &net.IPNet{ + IP: rm.Attributes.Dst, + Mask: net.CIDRMask(int(rm.DstLength), zeros), } } -func (rs *routeStats) add(ru netlink.RouteUpdate) { - key := ru.Route.Dst.String() +func (rs *routeStats) add(rm rtnetlink.RouteMessage) { + rs.Lock() + defer rs.Unlock() + + dst := DstIPNet(rm) + key := dst.String() r := rs.stats[key] if r == nil { - r = &route{Dst: ru.Route.Dst} + r = &route{Dst: dst} } r.Counter++ if r.UnreachableSince == nil { @@ -93,11 +100,15 @@ func (rs *routeStats) add(ru netlink.RouteUpdate) { rs.stats[key] = r } -func (rs *routeStats) del(ru netlink.RouteUpdate) { - key := ru.Route.Dst.String() +func (rs *routeStats) del(rm rtnetlink.RouteMessage) { + rs.Lock() + defer rs.Unlock() + + dst := DstIPNet(rm) + key := dst.String() r := rs.stats[key] if r == nil { - r = &route{Dst: ru.Route.Dst} + r = &route{Dst: dst} return } r.UnreachableDuration += time.Since(*r.UnreachableSince) @@ -129,25 +140,44 @@ func (rs *routeStats) getLongest() []route { } func monitor(done <-chan struct{}, rs *routeStats) { - rups := make(chan netlink.RouteUpdate, 64) - opts := netlink.RouteSubscribeOptions{ - ErrorCallback: func(err error) { - fmt.Fprintf(os.Stderr, "%s\n", err, err) - }, - } - if err := netlink.RouteSubscribeWithOptions(rups, done, opts); err != nil { + c, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + Groups: 1<<(unix.RTNLGRP_IPV4_ROUTE-1) | 1<<(unix.RTNLGRP_IPV6_ROUTE-1), + }) + if err != nil { log.Fatal(err) } - for ru := range rups { - if ru.Route.Type != unix.RTN_UNREACHABLE { - continue + go func() { + <-done + c.Close() + }() + + if err = c.SetReadBuffer(16 * 1048576); err != nil { + log.Fatal(err) + } + + for { + ms, err := c.Receive() + if err != nil { + log.Println(err) } - // show unly babel routes - if ru.Route.Protocol != 42 { - continue + + for _, m := range ms { + rm := rtnetlink.RouteMessage{} + if err = rm.UnmarshalBinary(m.Data); err != nil { + log.Fatal(err) + } + if rm.Type != unix.RTN_UNREACHABLE || + rm.Protocol != 42 { + continue + } + switch m.Header.Type { + case unix.RTM_NEWROUTE: + rs.add(rm) + case unix.RTM_DELROUTE: + rs.del(rm) + } } - rs.update(ru) } }