| Jan Schär | ec03df4 | 2025-02-27 14:30:45 +0100 | [diff] [blame] | 1 | commit 0a454ac56a5f6e9343e0bfafa31fd63d5dc831b5 |
| 2 | Author: Jan Schär <jan@monogon.tech> |
| 3 | Date: Wed Feb 26 18:27:57 2025 +0100 |
| 4 | |
| 5 | Split set elements into batches if needed |
| 6 | |
| 7 | If the number of elements to be added to or removed from a set is large, |
| 8 | they may not all fit into one message, because the size field of a |
| 9 | netlink attribute is a uint16 and would overflow. To support this case, |
| 10 | the elements need to be split into multiple batches. |
| 11 | |
| 12 | Upstream PR: https://github.com/google/nftables/pull/303 |
| 13 | |
| 14 | diff --git a/set.go b/set.go |
| 15 | index 412d75a..4d1dcae 100644 |
| 16 | --- a/set.go |
| 17 | +++ b/set.go |
| 18 | @@ -375,24 +375,31 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { |
| 19 | if s.Anonymous { |
| 20 | return errors.New("anonymous sets cannot be updated") |
| 21 | } |
| 22 | + return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) |
| 23 | +} |
| 24 | |
| 25 | - elements, err := s.makeElemList(vals, s.ID) |
| 26 | - if err != nil { |
| 27 | - return err |
| 28 | +// SetDeleteElements deletes data points from an nftables set. |
| 29 | +func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { |
| 30 | + cc.mu.Lock() |
| 31 | + defer cc.mu.Unlock() |
| 32 | + if s.Anonymous { |
| 33 | + return errors.New("anonymous sets cannot be updated") |
| 34 | } |
| 35 | - cc.messages = append(cc.messages, netlinkMessage{ |
| 36 | - Header: netlink.Header{ |
| 37 | - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), |
| 38 | - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| 39 | - }, |
| 40 | - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), |
| 41 | - }) |
| 42 | - |
| 43 | - return nil |
| 44 | + return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) |
| 45 | } |
| 46 | |
| 47 | -func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) { |
| 48 | +// maxElemBatchSize is the maximum size in bytes of encoded set elements which |
| 49 | +// are sent in one netlink message. The size field of a netlink attribute is a |
| 50 | +// uint16, and 1024 bytes is more than enough for the per-message headers. |
| 51 | +const maxElemBatchSize = 0x10000 - 1024 |
| 52 | + |
| 53 | +func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error { |
| 54 | + if len(vals) == 0 { |
| 55 | + return nil |
| 56 | + } |
| 57 | var elements []netlink.Attribute |
| 58 | + batchSize := 0 |
| 59 | + var batches [][]netlink.Attribute |
| 60 | |
| 61 | for i, v := range vals { |
| 62 | item := make([]netlink.Attribute, 0) |
| 63 | @@ -404,14 +411,14 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| 64 | |
| 65 | encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) |
| 66 | if err != nil { |
| 67 | - return nil, fmt.Errorf("marshal key %d: %v", i, err) |
| 68 | + return fmt.Errorf("marshal key %d: %v", i, err) |
| 69 | } |
| 70 | |
| 71 | item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) |
| 72 | if len(v.KeyEnd) > 0 { |
| 73 | encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) |
| 74 | if err != nil { |
| 75 | - return nil, fmt.Errorf("marshal key end %d: %v", i, err) |
| 76 | + return fmt.Errorf("marshal key end %d: %v", i, err) |
| 77 | } |
| 78 | item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) |
| 79 | } |
| 80 | @@ -431,7 +438,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| 81 | {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, |
| 82 | }) |
| 83 | if err != nil { |
| 84 | - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| 85 | + return fmt.Errorf("marshal item %d: %v", i, err) |
| 86 | } |
| 87 | encodedVal = append(encodedVal, encodedKind...) |
| 88 | if len(v.VerdictData.Chain) != 0 { |
| 89 | @@ -439,21 +446,21 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| 90 | {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, |
| 91 | }) |
| 92 | if err != nil { |
| 93 | - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| 94 | + return fmt.Errorf("marshal item %d: %v", i, err) |
| 95 | } |
| 96 | encodedVal = append(encodedVal, encodedChain...) |
| 97 | } |
| 98 | encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ |
| 99 | {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) |
| 100 | if err != nil { |
| 101 | - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| 102 | + return fmt.Errorf("marshal item %d: %v", i, err) |
| 103 | } |
| 104 | item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) |
| 105 | case len(v.Val) > 0: |
| 106 | // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes |
| 107 | encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) |
| 108 | if err != nil { |
| 109 | - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| 110 | + return fmt.Errorf("marshal item %d: %v", i, err) |
| 111 | } |
| 112 | |
| 113 | item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) |
| 114 | @@ -469,22 +476,42 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e |
| 115 | |
| 116 | encodedItem, err := netlink.MarshalAttributes(item) |
| 117 | if err != nil { |
| 118 | - return nil, fmt.Errorf("marshal item %d: %v", i, err) |
| 119 | + return fmt.Errorf("marshal item %d: %v", i, err) |
| 120 | + } |
| 121 | + |
| 122 | + itemSize := unix.NLA_HDRLEN + len(encodedItem) |
| 123 | + if batchSize+itemSize > maxElemBatchSize { |
| 124 | + batches = append(batches, elements) |
| 125 | + elements = nil |
| 126 | + batchSize = 0 |
| 127 | } |
| 128 | elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem}) |
| 129 | + batchSize += itemSize |
| 130 | } |
| 131 | + batches = append(batches, elements) |
| 132 | |
| 133 | - encodedElem, err := netlink.MarshalAttributes(elements) |
| 134 | - if err != nil { |
| 135 | - return nil, fmt.Errorf("marshal elements: %v", err) |
| 136 | - } |
| 137 | + for _, batch := range batches { |
| 138 | + encodedElem, err := netlink.MarshalAttributes(batch) |
| 139 | + if err != nil { |
| 140 | + return fmt.Errorf("marshal elements: %v", err) |
| 141 | + } |
| 142 | |
| 143 | - return []netlink.Attribute{ |
| 144 | - {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, |
| 145 | - {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)}, |
| 146 | - {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| 147 | - {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, |
| 148 | - }, nil |
| 149 | + message := []netlink.Attribute{ |
| 150 | + {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")}, |
| 151 | + {Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, |
| 152 | + {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| 153 | + {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, |
| 154 | + } |
| 155 | + |
| 156 | + cc.messages = append(cc.messages, netlinkMessage{ |
| 157 | + Header: netlink.Header{ |
| 158 | + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), |
| 159 | + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| 160 | + }, |
| 161 | + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(message)...), |
| 162 | + }) |
| 163 | + } |
| 164 | + return nil |
| 165 | } |
| 166 | |
| 167 | // AddSet adds the specified Set. |
| 168 | @@ -659,22 +686,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { |
| 169 | }) |
| 170 | |
| 171 | // Set the values of the set if initial values were provided. |
| 172 | - if len(vals) > 0 { |
| 173 | - hdrType := unix.NFT_MSG_NEWSETELEM |
| 174 | - elements, err := s.makeElemList(vals, s.ID) |
| 175 | - if err != nil { |
| 176 | - return err |
| 177 | - } |
| 178 | - cc.messages = append(cc.messages, netlinkMessage{ |
| 179 | - Header: netlink.Header{ |
| 180 | - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), |
| 181 | - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| 182 | - }, |
| 183 | - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), |
| 184 | - }) |
| 185 | - } |
| 186 | - |
| 187 | - return nil |
| 188 | + return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) |
| 189 | } |
| 190 | |
| 191 | // DelSet deletes a specific set, along with all elements it contains. |
| 192 | @@ -694,29 +706,6 @@ func (cc *Conn) DelSet(s *Set) { |
| 193 | }) |
| 194 | } |
| 195 | |
| 196 | -// SetDeleteElements deletes data points from an nftables set. |
| 197 | -func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { |
| 198 | - cc.mu.Lock() |
| 199 | - defer cc.mu.Unlock() |
| 200 | - if s.Anonymous { |
| 201 | - return errors.New("anonymous sets cannot be updated") |
| 202 | - } |
| 203 | - |
| 204 | - elements, err := s.makeElemList(vals, s.ID) |
| 205 | - if err != nil { |
| 206 | - return err |
| 207 | - } |
| 208 | - cc.messages = append(cc.messages, netlinkMessage{ |
| 209 | - Header: netlink.Header{ |
| 210 | - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), |
| 211 | - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, |
| 212 | - }, |
| 213 | - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), |
| 214 | - }) |
| 215 | - |
| 216 | - return nil |
| 217 | -} |
| 218 | - |
| 219 | // FlushSet deletes all data points from an nftables set. |
| 220 | func (cc *Conn) FlushSet(s *Set) { |
| 221 | cc.mu.Lock() |
| 222 | @@ -972,8 +961,8 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { |
| 223 | defer func() { _ = closer() }() |
| 224 | |
| 225 | data, err := netlink.MarshalAttributes([]netlink.Attribute{ |
| 226 | - {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| 227 | - {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, |
| 228 | + {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")}, |
| 229 | + {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")}, |
| 230 | }) |
| 231 | if err != nil { |
| 232 | return nil, err |