diff --git a/main.go b/main.go index 7b7be47..506ff66 100644 --- a/main.go +++ b/main.go @@ -2,17 +2,203 @@ package main import ( "encoding/binary" + "errors" "flag" "fmt" "log" "math" "net" "net/netip" + "sync" + "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" ) -func vx46(natprefix netip.Addr, upstreamAddr netip.Addr, port uint16, mtu uint16) error { +const ( + NUM_MSGS = 64 + NUM_BUFFERS = 1 +) + +var ( + OOB_SIZE = unix.CmsgSpace(16 + 2) // ipv6 address + port +) + +func (vx *vx46) transform46(msgs4 []ipv4.Message, msgs6 []ipv6.Message) error { + // embed the "client" ipv4 into the src address for the packet to the upstream vxlan ipv6 endpoint + // the destination is the upstream vxlan endpoint + for i := range msgs4 { + cm4 := ipv4.ControlMessage{} + if err := cm4.Parse(msgs4[i].OOB[:msgs4[i].NN]); err != nil { + return err + } + + inUDPAddr4 := msgs4[i].Addr.(*net.UDPAddr) + egressSrcAddr := vx.natprefix.As16() + copy(egressSrcAddr[10:14], inUDPAddr4.IP.To4()[:4]) // panics if To4 returns nil + binary.BigEndian.PutUint16(egressSrcAddr[14:16], uint16(inUDPAddr4.Port)) + + if vx.oobkey != egressSrcAddr { + cm6 := ipv6.ControlMessage{Src: net.IP(egressSrcAddr[:])} + vx.oobcache = cm6.Marshal() + } + + msgs6[i].Buffers[0] = msgs4[i].Buffers[0][:msgs4[i].N] + msgs6[i].OOB = vx.oobcache + msgs6[i].Addr = vx.upstream + msgs6[i].N = msgs4[i].N + msgs6[i].NN = len(msgs6[i].OOB) + } + + return nil +} + +func (vx *vx46) forward46() error { + //msgs4, msgs6 := prepareMsgs(vx.mtu) + msgs4 := [NUM_MSGS]ipv4.Message{} + msgs6 := [NUM_MSGS]ipv6.Message{} + + for i := range msgs6 { + msgs4[i].Buffers = [][]byte{make([]byte, vx.mtu)} + msgs6[i].Buffers = [][]byte{nil} + } + + for { + n, err := vx.pc4.ReadBatch(msgs4[:], 0) + if n > NUM_MSGS*4/3 { + log.Printf("forward46: %d in", n) + } + if err != nil { + return err + } + err = vx.transform46(msgs4[:n], msgs6[:n]) + if err != nil { + return err + } + outn, err := vx.pc6.WriteBatch(msgs6[:n], 0) + if err != nil { + return err + } + if outn != n { + return fmt.Errorf("Dropped messages. Sent %d of %d.", outn, n) + } + } +} + +func (vx *vx46) transform64(msgs6 []ipv6.Message, msgs4 []ipv4.Message) error { + // embed the "client" ipv4 into the src address for the packet to the upstream vxlan ipv6 endpoint + // the destination is the upstream vxlan endpoint + for i := range msgs6 { + cm6 := ipv6.ControlMessage{} + if err := cm6.Parse(msgs6[i].OOB[:msgs6[i].NN]); err != nil { + return err + } + if cm6.Dst.To16() == nil { + return fmt.Errorf("Destination information not available") + } + + msgs4[i].Buffers[0] = msgs6[i].Buffers[0][:msgs6[i].N] + if msgs4[i].Addr == nil { + msgs4[i].Addr = &net.UDPAddr{} + } + addr := msgs4[i].Addr.(*net.UDPAddr) + addr.IP = cm6.Dst[10:14] + addr.Port = int(binary.BigEndian.Uint16(cm6.Dst[14:16])) + msgs4[i].N = msgs6[i].N + msgs4[i].OOB = nil + msgs4[i].NN = 0 + } + + return nil +} + +func (vx *vx46) forward64() error { + //msgs4, msgs6 := prepareMsgs(vx.mtu) + msgs4 := [NUM_MSGS]ipv4.Message{} + msgs6 := [NUM_MSGS]ipv6.Message{} + + for i := range msgs6 { + msgs6[i].Buffers = [][]byte{make([]byte, vx.mtu)} + msgs6[i].OOB = make([]byte, OOB_SIZE) + msgs4[i].Buffers = [][]byte{nil} + } + + for { + n, err := vx.pc6.ReadBatch(msgs6[:], 0) + if n > NUM_MSGS*4/3 { + log.Printf("forward64: %d in", n) + } + if err != nil { + return err + } + err = vx.transform64(msgs6[:n], msgs4[:n]) + if err != nil { + return err + } + outn, err := vx.pc4.WriteBatch(msgs4[:n], 0) + if err != nil { + return err + } + if outn != n { + return fmt.Errorf("Dropped messages. Sent %d of %d.", outn, n) + } + } +} + +func (vx *vx46) forward() error { + l4, err := net.ListenUDP("udp4", &net.UDPAddr{Port: int(vx.port)}) + if err != nil { + return err + } + + l6, err := net.ListenUDP("udp6", &net.UDPAddr{Port: int(vx.port)}) + if err != nil { + return err + } + + vx.pc4 = ipv4.NewPacketConn(l4) + vx.pc6 = ipv6.NewPacketConn(l6) + + vx.pc6.SetControlMessage(ipv6.FlagDst, true) + + var wg sync.WaitGroup + var err4 error + var err6 error + + wg.Add(1) + go func() { + err4 = vx.forward46() + l6.Close() + wg.Done() + }() + wg.Add(1) + go func() { + err6 = vx.forward64() + l4.Close() + wg.Done() + }() + + wg.Wait() + + return errors.Join(err4, err6) +} + +type vx46 struct { + pc4 *ipv4.PacketConn + pc6 *ipv6.PacketConn + natprefix netip.Addr + upstreamAddr netip.Addr + upstream *net.UDPAddr + port uint16 + mtu uint16 + buffers int + + oobkey [16]byte + oobcache []byte +} + +func vx46forward(natprefix netip.Addr, upstreamAddr netip.Addr, port uint16, mtu uint16) error { upstream := netip.AddrPortFrom(upstreamAddr, port) p, err := net.ListenUDP("udp", &net.UDPAddr{Port: int(port)}) if err != nil { @@ -25,8 +211,8 @@ func vx46(natprefix netip.Addr, upstreamAddr netip.Addr, port uint16, mtu uint16 defer p.Close() - b := make([]byte, mtu+14+8) // inner ethernet header + vxlan header - var oob [20480]byte // from /proc/sys/net/core/optmem_max + b := make([]byte, mtu) // inner ethernet header + vxlan header + oob := make([]byte, OOB_SIZE) for { n, oobn, _, ingressSrcAddrPort, err := p.ReadMsgUDPAddrPort(b[:], oob[:]) @@ -81,8 +267,9 @@ func vx46(natprefix netip.Addr, upstreamAddr netip.Addr, port uint16, mtu uint16 func main() { natprefixStr := flag.String("prefix", "", "local IPv6 base address for a /80 to use for communication with upstream") upstreamStr := flag.String("upstream", "", "IPv6 address of the upstream VXLAN endpoint") - portInt := flag.Uint("port", 8472, "port for vxlan communication") + portInt := flag.Int("port", 8472, "port for vxlan communication") mtuInt := flag.Uint("mtu", 1422, "buffer size") + buffersInt := flag.Int("buffers", 64, "number of buffers for I/O batching") flag.Parse() @@ -90,7 +277,7 @@ func main() { if err != nil { log.Fatalf("Invalid prefix: %s", err) } - upstream, err := netip.ParseAddr(*upstreamStr) + upstreamAddr, err := netip.ParseAddr(*upstreamStr) if err != nil { log.Fatalf("Invalid upstream: %s", err) } @@ -100,11 +287,27 @@ func main() { port := uint16(*portInt) if *mtuInt > math.MaxUint16 { - log.Fatalf("mtu out of range: %d", *portInt) + log.Fatalf("mtu out of range: %d", *mtuInt) } - mtu := uint16(*mtuInt) + mtu := uint16(*mtuInt + 14 + 8) // inner ethernet header + vxlan header - if err := vx46(natprefix, upstream, port, mtu); err != nil { - log.Println(err) + if *buffersInt < 1 { + log.Fatalf("buffers < 1: %d", *mtuInt) + } + + if *buffersInt == 1 { + log.Println(vx46forward(natprefix, upstreamAddr, port, mtu)) + } else { + vx := vx46{ + natprefix: natprefix, + upstream: &net.UDPAddr{ + IP: net.IP(upstreamAddr.AsSlice()), + Port: *portInt, + }, + port: port, + mtu: mtu, + buffers: *buffersInt, + } + log.Println(vx.forward()) } }