| commit 0a454ac56a5f6e9343e0bfafa31fd63d5dc831b5 |
| Author: Jan Schรคr <jan@monogon.tech> |
| Date: Wed Feb 26 18:27:57 2025 +0100 |
| |
| Split set elements into batches if needed |
| |
| If the number of elements to be added to or removed from a set is large, |
| they may not all fit into one message, because the size field of a |
| netlink attribute is a uint16 and would overflow. To support this case, |
| the elements need to be split into multiple batches. |
| |
| Upstream PR: https://github.com/google/nftables/pull/303 |
| |
| diff --git a/set.go b/set.go |
| index 412d75a..4d1dcae 100644 |
| --- a/set.go |
| +++ b/set.go |
| @@ -375,24 +375,31 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { |
| if s.Anonymous { |
| return errors.New("anonymous sets cannot be updated") |
| } |
| + return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) |
| +} |
| |
| - elements, err := s.makeElemList(vals, s.ID) |
| - if err != nil { |
| - return err |
| +// SetDeleteElements deletes data points from an nftables set. |
| +func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { |
| + cc.mu.Lock() |
| + defer cc.mu.Unlock() |
| + if s.Anonymous { |
| + return errors.New("anonymous sets cannot be updated") |
| } |
| - cc.messages = append(cc.messages, netlinkMessage{ |
| - Header: netlink.Header{ |
| - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), |
| - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| - }, |
| - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), |
| - }) |
| - |
| - return nil |
| + return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) |
| } |
| |
| -func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) { |
| +// maxElemBatchSize is the maximum size in bytes of encoded set elements which |
| +// are sent in one netlink message. The size field of a netlink attribute is a |
| +// uint16, and 1024 bytes is more than enough for the per-message headers. |
| +const maxElemBatchSize = 0x10000 - 1024 |
| + |
| +func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error { |
| + if len(vals) == 0 { |
| + return nil |
| + } |
| var elements []netlink.Attribute |
| + batchSize := 0 |
| + var batches [][]netlink.Attribute |
| |
| for i, v := range vals { |
| item := make([]netlink.Attribute, 0) |
| @@ -404,14 +411,14 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| |
| encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) |
| if err != nil { |
| - return nil, fmt.Errorf("marshal key %d: %v", i, err) |
| + return fmt.Errorf("marshal key %d: %v", i, err) |
| } |
| |
| item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) |
| if len(v.KeyEnd) > 0 { |
| encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) |
| if err != nil { |
| - return nil, fmt.Errorf("marshal key end %d: %v", i, err) |
| + return fmt.Errorf("marshal key end %d: %v", i, err) |
| } |
| item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) |
| } |
| @@ -431,7 +438,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, |
| }) |
| if err != nil { |
| - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| + return fmt.Errorf("marshal item %d: %v", i, err) |
| } |
| encodedVal = append(encodedVal, encodedKind...) |
| if len(v.VerdictData.Chain) != 0 { |
| @@ -439,21 +446,21 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, |
| }) |
| if err != nil { |
| - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| + return fmt.Errorf("marshal item %d: %v", i, err) |
| } |
| encodedVal = append(encodedVal, encodedChain...) |
| } |
| encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ |
| {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) |
| if err != nil { |
| - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| + return fmt.Errorf("marshal item %d: %v", i, err) |
| } |
| item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) |
| case len(v.Val) > 0: |
| // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes |
| encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) |
| if err != nil { |
| - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| + return fmt.Errorf("marshal item %d: %v", i, err) |
| } |
| |
| item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) |
| @@ -469,22 +476,42 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| |
| encodedItem, err := netlink.MarshalAttributes(item) |
| if err != nil { |
| - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| + return fmt.Errorf("marshal item %d: %v", i, err) |
| + } |
| + |
| + itemSize := unix.NLA_HDRLEN + len(encodedItem) |
| + if batchSize+itemSize > maxElemBatchSize { |
| + batches = append(batches, elements) |
| + elements = nil |
| + batchSize = 0 |
| } |
| elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem}) |
| + batchSize += itemSize |
| } |
| + batches = append(batches, elements) |
| |
| - encodedElem, err := netlink.MarshalAttributes(elements) |
| - if err != nil { |
| - return nil, fmt.Errorf("marshal elements: %v", err) |
| - } |
| + for _, batch := range batches { |
| + encodedElem, err := netlink.MarshalAttributes(batch) |
| + if err != nil { |
| + return fmt.Errorf("marshal elements: %v", err) |
| + } |
| |
| - return []netlink.Attribute{ |
| - {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, |
| - {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)}, |
| - {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| - {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, |
| - }, nil |
| + message := []netlink.Attribute{ |
| + {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")}, |
| + {Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, |
| + {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| + {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, |
| + } |
| + |
| + cc.messages = append(cc.messages, netlinkMessage{ |
| + Header: netlink.Header{ |
| + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), |
| + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| + }, |
| + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(message)...), |
| + }) |
| + } |
| + return nil |
| } |
| |
| // AddSet adds the specified Set. |
| @@ -659,22 +686,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { |
| }) |
| |
| // Set the values of the set if initial values were provided. |
| - if len(vals) > 0 { |
| - hdrType := unix.NFT_MSG_NEWSETELEM |
| - elements, err := s.makeElemList(vals, s.ID) |
| - if err != nil { |
| - return err |
| - } |
| - cc.messages = append(cc.messages, netlinkMessage{ |
| - Header: netlink.Header{ |
| - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), |
| - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| - }, |
| - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), |
| - }) |
| - } |
| - |
| - return nil |
| + return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) |
| } |
| |
| // DelSet deletes a specific set, along with all elements it contains. |
| @@ -694,29 +706,6 @@ func (cc *Conn) DelSet(s *Set) { |
| }) |
| } |
| |
| -// SetDeleteElements deletes data points from an nftables set. |
| -func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { |
| - cc.mu.Lock() |
| - defer cc.mu.Unlock() |
| - if s.Anonymous { |
| - return errors.New("anonymous sets cannot be updated") |
| - } |
| - |
| - elements, err := s.makeElemList(vals, s.ID) |
| - if err != nil { |
| - return err |
| - } |
| - cc.messages = append(cc.messages, netlinkMessage{ |
| - Header: netlink.Header{ |
| - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), |
| - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| - }, |
| - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), |
| - }) |
| - |
| - return nil |
| -} |
| - |
| // FlushSet deletes all data points from an nftables set. |
| func (cc *Conn) FlushSet(s *Set) { |
| cc.mu.Lock() |
| @@ -972,8 +961,8 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { |
| defer func() { _ = closer() }() |
| |
| data, err := netlink.MarshalAttributes([]netlink.Attribute{ |
| - {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| - {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, |
| + {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| + {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")}, |
| }) |
| if err != nil { |
| return nil, err |