abbel/packet/packet.go

135 lines
2.7 KiB
Go

package packet
import (
"context"
"fmt"
"net"
"net/netip"
"syscall"
"golang.org/x/net/ipv6"
)
// Babel wire-format constants checked by Validate.
const (
	// BabelMagic is the value expected in the first byte of every packet.
	BabelMagic = 42
	// BabelVersion is the only protocol version accepted by Validate.
	BabelVersion = 2
	// BabelPacketHeaderSize is the size in bytes of the fixed packet
	// header: magic (1), version (1) and a 16-bit big-endian body length (2).
	BabelPacketHeaderSize = 4
)
// Conn is a UDP-over-IPv6 socket used to read Babel packets and to join
// multicast groups on individual interfaces. Its methods use value
// receivers; copies of a Conn share the same underlying socket.
type Conn struct {
	// v6pc is the wrapped packet connection. It is nil for the zero Conn.
	v6pc *ipv6.PacketConn
}
// Listen opens a socket on the given UDP port (see ListenPort) and joins
// the multicast group on each named interface.
//
// Passing the single interface name "any" joins the group on every
// interface reported by net.Interfaces that has the multicast flag set.
// Interfaces without it are skipped: joining a group there may fail, and a
// failing join would otherwise abort the whole call.
//
// On any error the socket is closed and a zero Conn is returned.
func Listen(group string, port uint16, ifaces ...string) (Conn, error) {
	conn, err := ListenPort(port)
	if err != nil {
		return Conn{}, err
	}
	if len(ifaces) == 1 && ifaces[0] == "any" {
		netifs, err := net.Interfaces()
		if err != nil {
			// The socket is already open; don't leak it.
			conn.Close()
			return Conn{}, err
		}
		// Reassigning the variadic parameter does not touch the caller's slice.
		ifaces = make([]string, 0, len(netifs))
		for _, netif := range netifs {
			if netif.Flags&net.FlagMulticast == 0 {
				continue
			}
			ifaces = append(ifaces, netif.Name)
		}
	}
	for _, iface := range ifaces {
		if err := conn.JoinGroup(iface, group); err != nil {
			conn.Close()
			return Conn{}, err
		}
	}
	return conn, nil
}
// ListenPort opens a "udp6" socket bound to port on the wildcard address.
// The socket is created with SO_REUSEADDR set. It enables the ipv6 control
// messages FlagDst, FlagTrafficClass, FlagHopLimit and FlagPathMTU so that
// ReadFrom receives per-packet ancillary data. No multicast group is
// joined; see JoinGroup.
//
// On error any partially opened socket is closed and a zero Conn is
// returned.
func ListenPort(port uint16) (Conn, error) {
	lc := net.ListenConfig{
		// Runs after the socket is created but before it is bound. Both
		// the RawConn.Control error and the setsockopt error are reported.
		Control: func(network, address string, rc syscall.RawConn) error {
			var sockErr error
			if err := rc.Control(func(fd uintptr) {
				sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
			}); err != nil {
				return err
			}
			return sockErr
		},
	}
	uc, err := lc.ListenPacket(context.Background(), "udp6", fmt.Sprintf(":%d", port))
	if err != nil {
		return Conn{}, err
	}
	pc := ipv6.NewPacketConn(uc)
	const flags = ipv6.FlagDst | ipv6.FlagTrafficClass | ipv6.FlagHopLimit | ipv6.FlagPathMTU
	if err := pc.SetControlMessage(flags, true); err != nil {
		// Don't hand back (or leak) a socket we failed to configure.
		pc.Close()
		return Conn{}, err
	}
	return Conn{v6pc: pc}, nil
}
// Close closes the underlying socket.
//
// Closing a zero Conn (for example the value Listen returns alongside an
// error) is a no-op and returns nil, so a premature `defer conn.Close()`
// cannot panic.
func (c Conn) Close() error {
	if c.v6pc == nil {
		return nil
	}
	return c.v6pc.Close()
}
// JoinGroup joins the multicast group addr (a textual IP address such as
// "ff02::1:6") on the interface named ifname. Any zone in addr is ignored.
//
// It returns an error if the interface does not exist, addr does not
// parse, addr is not a multicast address, or the join itself fails. The
// join failure is wrapped so the interface and group appear in the message.
func (c Conn) JoinGroup(ifname string, addr string) error {
	ifi, err := net.InterfaceByName(ifname)
	if err != nil {
		return fmt.Errorf("InterfaceByName(%s): %w", ifname, err)
	}
	ip, err := netip.ParseAddr(addr)
	if err != nil {
		return err
	}
	// Reject unicast addresses here rather than relying on a less
	// descriptive error from the socket option.
	if !ip.IsMulticast() {
		return fmt.Errorf("JoinGroup(%s, %s): not a multicast address", ifname, addr)
	}
	if err := c.v6pc.JoinGroup(ifi, &net.UDPAddr{IP: ip.AsSlice()}); err != nil {
		return fmt.Errorf("JoinGroup(%s, %s): %w", ifname, addr, err)
	}
	return nil
}
// Validate checks the Babel packet header at the start of b and returns
// the packet body.
//
// The header is BabelPacketHeaderSize bytes: magic, version and a 16-bit
// big-endian body length. The returned body is exactly that many bytes
// following the header. Any bytes after the body (the packet trailer) are
// excluded, so they are not mistaken for body TLVs. The result aliases b.
//
// An error is returned if b is shorter than the header, the magic or
// version does not match, or the declared body length exceeds the bytes
// available.
func Validate(b []byte) ([]byte, error) {
	n := len(b)
	if n < BabelPacketHeaderSize {
		return nil, fmt.Errorf("Packet too short: %d", len(b))
	}
	magic := b[0]
	version := b[1]
	length := int(b[2])<<8 | int(b[3])
	if magic != BabelMagic {
		return nil, fmt.Errorf("Invalid magic number %d", magic)
	}
	if version != BabelVersion {
		return nil, fmt.Errorf("Unsupported version number %d", version)
	}
	rest := b[BabelPacketHeaderSize:]
	if length > len(rest) {
		return nil, fmt.Errorf("Invalid length for packet of size %d: %d", n, length)
	}
	return rest[:length], nil
}
func (c Conn) ReadFrom(b []byte) (body []byte, src netip.Addr, ifindex int, err error) {
n, rcm, _, err := c.v6pc.ReadFrom(b)
if err != nil {
return nil, netip.Addr{}, 0, err
}
b, err = Validate(b[:n])
if err != nil {
return nil, netip.Addr{}, 0, err
}
var ok bool
src, ok = netip.AddrFromSlice(rcm.Src)
if !ok {
return nil, netip.Addr{}, 0, fmt.Errorf("Invalid src address %q", rcm.Src)
}
return b, src, rcm.IfIndex, err
}