blob: 4723b16b60bea44f90a294e29a520a9322380b15 [file] [log] [blame]
commit 0a454ac56a5f6e9343e0bfafa31fd63d5dc831b5
Author: Jan Schรคr <jan@monogon.tech>
Date: Wed Feb 26 18:27:57 2025 +0100
Split set elements into batches if needed
If the number of elements to be added to or removed from a set is large,
they may not all fit into one message, because the size field of a
netlink attribute is a uint16 and would overflow. To support this case,
the elements need to be split into multiple batches.
Upstream PR: https://github.com/google/nftables/pull/303
diff --git a/set.go b/set.go
index 412d75a..4d1dcae 100644
--- a/set.go
+++ b/set.go
@@ -375,24 +375,31 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
+ return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
+}
- elements, err := s.makeElemList(vals, s.ID)
- if err != nil {
- return err
+// SetDeleteElements deletes data points from an nftables set.
+func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if s.Anonymous {
+ return errors.New("anonymous sets cannot be updated")
}
- cc.messages = append(cc.messages, netlinkMessage{
- Header: netlink.Header{
- Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM),
- Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
- },
- Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
- })
-
- return nil
+ return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
}
-func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) {
+// maxElemBatchSize is the maximum size in bytes of encoded set elements which
+// are sent in one netlink message. The size field of a netlink attribute is a
+// uint16, and 1024 bytes is more than enough for the per-message headers.
+const maxElemBatchSize = 0x10000 - 1024
+
+func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error {
+ if len(vals) == 0 {
+ return nil
+ }
var elements []netlink.Attribute
+ batchSize := 0
+ var batches [][]netlink.Attribute
for i, v := range vals {
item := make([]netlink.Attribute, 0)
@@ -404,14 +411,14 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
if err != nil {
- return nil, fmt.Errorf("marshal key %d: %v", i, err)
+ return fmt.Errorf("marshal key %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
if len(v.KeyEnd) > 0 {
encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}})
if err != nil {
- return nil, fmt.Errorf("marshal key end %d: %v", i, err)
+ return fmt.Errorf("marshal key end %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd})
}
@@ -431,7 +438,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))},
})
if err != nil {
- return nil, fmt.Errorf("marshal item %d: %v", i, err)
+ return fmt.Errorf("marshal item %d: %v", i, err)
}
encodedVal = append(encodedVal, encodedKind...)
if len(v.VerdictData.Chain) != 0 {
@@ -439,21 +446,21 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
{Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")},
})
if err != nil {
- return nil, fmt.Errorf("marshal item %d: %v", i, err)
+ return fmt.Errorf("marshal item %d: %v", i, err)
}
encodedVal = append(encodedVal, encodedChain...)
}
encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}})
if err != nil {
- return nil, fmt.Errorf("marshal item %d: %v", i, err)
+ return fmt.Errorf("marshal item %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict})
case len(v.Val) > 0:
// Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
if err != nil {
- return nil, fmt.Errorf("marshal item %d: %v", i, err)
+ return fmt.Errorf("marshal item %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal})
@@ -469,22 +476,42 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
encodedItem, err := netlink.MarshalAttributes(item)
if err != nil {
- return nil, fmt.Errorf("marshal item %d: %v", i, err)
+ return fmt.Errorf("marshal item %d: %v", i, err)
+ }
+
+ itemSize := unix.NLA_HDRLEN + len(encodedItem)
+ if batchSize+itemSize > maxElemBatchSize {
+ batches = append(batches, elements)
+ elements = nil
+ batchSize = 0
}
elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem})
+ batchSize += itemSize
}
+ batches = append(batches, elements)
- encodedElem, err := netlink.MarshalAttributes(elements)
- if err != nil {
- return nil, fmt.Errorf("marshal elements: %v", err)
- }
+ for _, batch := range batches {
+ encodedElem, err := netlink.MarshalAttributes(batch)
+ if err != nil {
+ return fmt.Errorf("marshal elements: %v", err)
+ }
- return []netlink.Attribute{
- {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
- {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)},
- {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
- {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
- }, nil
+ message := []netlink.Attribute{
+ {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
+ {Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)},
+ {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
+ {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
+ }
+
+ cc.messages = append(cc.messages, netlinkMessage{
+ Header: netlink.Header{
+ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
+ Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
+ },
+ Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(message)...),
+ })
+ }
+ return nil
}
// AddSet adds the specified Set.
@@ -659,22 +686,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
})
// Set the values of the set if initial values were provided.
- if len(vals) > 0 {
- hdrType := unix.NFT_MSG_NEWSETELEM
- elements, err := s.makeElemList(vals, s.ID)
- if err != nil {
- return err
- }
- cc.messages = append(cc.messages, netlinkMessage{
- Header: netlink.Header{
- Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
- Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
- },
- Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
- })
- }
-
- return nil
+ return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
}
// DelSet deletes a specific set, along with all elements it contains.
@@ -694,29 +706,6 @@ func (cc *Conn) DelSet(s *Set) {
})
}
-// SetDeleteElements deletes data points from an nftables set.
-func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- if s.Anonymous {
- return errors.New("anonymous sets cannot be updated")
- }
-
- elements, err := s.makeElemList(vals, s.ID)
- if err != nil {
- return err
- }
- cc.messages = append(cc.messages, netlinkMessage{
- Header: netlink.Header{
- Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
- Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
- },
- Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
- })
-
- return nil
-}
-
// FlushSet deletes all data points from an nftables set.
func (cc *Conn) FlushSet(s *Set) {
cc.mu.Lock()
@@ -972,8 +961,8 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
- {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
- {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
+ {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
+ {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
})
if err != nil {
return nil, err