2021-12-14 10:17:05 +01:00
|
|
|
package packet
|
|
|
|
|
|
|
|
import (
|
2023-03-07 01:08:21 +01:00
|
|
|
"context"
|
2021-12-14 10:17:05 +01:00
|
|
|
"fmt"
|
|
|
|
"net"
|
2023-03-07 01:08:21 +01:00
|
|
|
"syscall"
|
2021-12-14 10:17:05 +01:00
|
|
|
|
|
|
|
"golang.org/x/net/ipv6"
|
|
|
|
"inet.af/netaddr"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Constants describing the fixed Babel packet header checked by Validate.
const (
	// BabelMagic is the required value of the first octet of a packet.
	BabelMagic = 42

	// BabelVersion is the only protocol version (second octet) accepted.
	BabelVersion = 2

	// BabelPacketHeaderSize is the header length in octets:
	// magic (1) + version (1) + big-endian body length (2).
	BabelPacketHeaderSize = 1 + 1 + 2
)
|
|
|
|
|
|
|
|
type Conn struct {
|
|
|
|
v6pc *ipv6.PacketConn
|
|
|
|
}
|
|
|
|
|
2021-12-14 15:23:45 +01:00
|
|
|
func Listen(group string, port uint16, ifaces ...string) (Conn, error) {
|
|
|
|
conn, err := ListenPort(port)
|
|
|
|
if err != nil {
|
|
|
|
return Conn{}, err
|
|
|
|
}
|
2023-03-07 01:09:52 +01:00
|
|
|
if len(ifaces) == 1 && ifaces[0] == "any" {
|
|
|
|
netifs, err := net.Interfaces()
|
|
|
|
if err != nil {
|
|
|
|
return Conn{}, err
|
|
|
|
}
|
|
|
|
ifaces = make([]string, 0, len(netifs))
|
|
|
|
for _, netif := range netifs {
|
|
|
|
ifaces = append(ifaces, netif.Name)
|
|
|
|
}
|
|
|
|
}
|
2021-12-14 15:23:45 +01:00
|
|
|
for _, iface := range ifaces {
|
|
|
|
if err = conn.JoinGroup(iface, group); err != nil {
|
|
|
|
conn.Close()
|
|
|
|
return Conn{}, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return conn, nil
|
|
|
|
}
|
2021-12-14 10:17:05 +01:00
|
|
|
func ListenPort(port uint16) (Conn, error) {
|
|
|
|
var c Conn
|
|
|
|
|
2023-03-07 01:08:21 +01:00
|
|
|
lc := net.ListenConfig{
|
|
|
|
Control: func(network, address string, c syscall.RawConn) error {
|
|
|
|
var err error
|
|
|
|
c.Control(func(fd uintptr) {
|
|
|
|
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
|
|
|
|
})
|
|
|
|
return err
|
|
|
|
},
|
|
|
|
}
|
|
|
|
uc, err := lc.ListenPacket(context.Background(), "udp6", fmt.Sprintf(":%d", port))
|
|
|
|
|
2021-12-14 10:17:05 +01:00
|
|
|
if err != nil {
|
|
|
|
return c, err
|
|
|
|
}
|
2023-03-07 01:08:21 +01:00
|
|
|
|
2021-12-14 10:17:05 +01:00
|
|
|
c.v6pc = ipv6.NewPacketConn(uc)
|
|
|
|
|
2021-12-14 15:23:45 +01:00
|
|
|
return c, c.v6pc.SetControlMessage(
|
|
|
|
ipv6.FlagDst|ipv6.FlagTrafficClass|ipv6.FlagHopLimit|ipv6.FlagPathMTU,
|
|
|
|
true)
|
2021-12-14 10:17:05 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) Close() error {
|
|
|
|
return c.v6pc.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) JoinGroup(ifname string, addr string) error {
|
|
|
|
ifi, err := net.InterfaceByName(ifname)
|
|
|
|
if err != nil {
|
2023-03-07 01:09:52 +01:00
|
|
|
return fmt.Errorf("InterfaceByName(%s): %w", ifname, err)
|
2021-12-14 10:17:05 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
ip, err := netaddr.ParseIP(addr)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return c.v6pc.JoinGroup(ifi, &net.UDPAddr{IP: ip.IPAddr().IP})
|
|
|
|
}
|
|
|
|
|
2021-12-15 11:20:16 +01:00
|
|
|
func Validate(b []byte) ([]byte, error) {
|
|
|
|
n := len(b)
|
2021-12-14 10:17:05 +01:00
|
|
|
|
2021-12-15 11:20:16 +01:00
|
|
|
if n < BabelPacketHeaderSize {
|
|
|
|
return nil, fmt.Errorf("Packet too short: %d", len(b))
|
2021-12-14 10:17:05 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
magic := b[0]
|
|
|
|
version := b[1]
|
|
|
|
length := uint16(b[2])<<8 + uint16(b[3])
|
|
|
|
b = b[4:]
|
|
|
|
|
|
|
|
if magic != BabelMagic {
|
2021-12-15 11:20:16 +01:00
|
|
|
return nil, fmt.Errorf("Invalid magic number %d", magic)
|
2021-12-14 10:17:05 +01:00
|
|
|
}
|
|
|
|
if version != BabelVersion {
|
2021-12-15 11:20:16 +01:00
|
|
|
return nil, fmt.Errorf("Unsupported version number %d", version)
|
2021-12-14 10:17:05 +01:00
|
|
|
}
|
|
|
|
if int(length) > len(b) {
|
2021-12-15 11:20:16 +01:00
|
|
|
return nil, fmt.Errorf("Invalid length for packet of size %d: %d", n, length)
|
|
|
|
}
|
|
|
|
|
|
|
|
return b, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) ReadFrom(b []byte) (body []byte, src netaddr.IP, ifindex int, err error) {
|
|
|
|
n, rcm, _, err := c.v6pc.ReadFrom(b)
|
|
|
|
if err != nil {
|
|
|
|
return nil, netaddr.IP{}, 0, err
|
|
|
|
}
|
|
|
|
|
|
|
|
b, err = Validate(b[:n])
|
|
|
|
if err != nil {
|
|
|
|
return nil, netaddr.IP{}, 0, err
|
2021-12-14 10:17:05 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
var ok bool
|
|
|
|
src, ok = netaddr.FromStdIPRaw(rcm.Src)
|
|
|
|
|
|
|
|
if !ok {
|
2021-12-15 11:20:16 +01:00
|
|
|
return nil, netaddr.IP{}, 0, fmt.Errorf("Invalid src address %q", rcm.Src)
|
2021-12-14 10:17:05 +01:00
|
|
|
}
|
|
|
|
|
2021-12-15 11:20:16 +01:00
|
|
|
return b, src, rcm.IfIndex, err
|
2021-12-14 10:17:05 +01:00
|
|
|
|
|
|
|
}
|