blob: 4723b16b60bea44f90a294e29a520a9322380b15 [file] [log] [blame]
Jan Schärec03df42025-02-27 14:30:45 +01001commit 0a454ac56a5f6e9343e0bfafa31fd63d5dc831b5
2Author: Jan Schär <jan@monogon.tech>
3Date: Wed Feb 26 18:27:57 2025 +0100
4
5 Split set elements into batches if needed
6
7 If the number of elements to be added to or removed from a set is large,
8 they may not all fit into one message, because the size field of a
9 netlink attribute is a uint16 and would overflow. To support this case,
10 the elements need to be split into multiple batches.
11
12 Upstream PR: https://github.com/google/nftables/pull/303
13
14diff --git a/set.go b/set.go
15index 412d75a..4d1dcae 100644
16--- a/set.go
17+++ b/set.go
18@@ -375,24 +375,31 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
19 if s.Anonymous {
20 return errors.New("anonymous sets cannot be updated")
21 }
22+ return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
23+}
24
25- elements, err := s.makeElemList(vals, s.ID)
26- if err != nil {
27- return err
28+// SetDeleteElements deletes data points from an nftables set.
29+func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
30+ cc.mu.Lock()
31+ defer cc.mu.Unlock()
32+ if s.Anonymous {
33+ return errors.New("anonymous sets cannot be updated")
34 }
35- cc.messages = append(cc.messages, netlinkMessage{
36- Header: netlink.Header{
37- Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM),
38- Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
39- },
40- Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
41- })
42-
43- return nil
44+ return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
45 }
46
47-func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) {
48+// maxElemBatchSize is the maximum size in bytes of encoded set elements which
49+// are sent in one netlink message. The size field of a netlink attribute is a
50+// uint16, and 1024 bytes is more than enough for the per-message headers.
51+const maxElemBatchSize = 0x10000 - 1024
52+
53+func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error {
54+ if len(vals) == 0 {
55+ return nil
56+ }
57 var elements []netlink.Attribute
58+ batchSize := 0
59+ var batches [][]netlink.Attribute
60
61 for i, v := range vals {
62 item := make([]netlink.Attribute, 0)
63@@ -404,14 +411,14 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
64
65 encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
66 if err != nil {
67- return nil, fmt.Errorf("marshal key %d: %v", i, err)
68+ return fmt.Errorf("marshal key %d: %v", i, err)
69 }
70
71 item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
72 if len(v.KeyEnd) > 0 {
73 encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}})
74 if err != nil {
75- return nil, fmt.Errorf("marshal key end %d: %v", i, err)
76+ return fmt.Errorf("marshal key end %d: %v", i, err)
77 }
78 item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd})
79 }
80@@ -431,7 +438,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
81 {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))},
82 })
83 if err != nil {
84- return nil, fmt.Errorf("marshal item %d: %v", i, err)
85+ return fmt.Errorf("marshal item %d: %v", i, err)
86 }
87 encodedVal = append(encodedVal, encodedKind...)
88 if len(v.VerdictData.Chain) != 0 {
89@@ -439,21 +446,21 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
90 {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")},
91 })
92 if err != nil {
93- return nil, fmt.Errorf("marshal item %d: %v", i, err)
94+ return fmt.Errorf("marshal item %d: %v", i, err)
95 }
96 encodedVal = append(encodedVal, encodedChain...)
97 }
98 encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{
99 {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}})
100 if err != nil {
101- return nil, fmt.Errorf("marshal item %d: %v", i, err)
102+ return fmt.Errorf("marshal item %d: %v", i, err)
103 }
104 item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict})
105 case len(v.Val) > 0:
106 // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes
107 encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
108 if err != nil {
109- return nil, fmt.Errorf("marshal item %d: %v", i, err)
110+ return fmt.Errorf("marshal item %d: %v", i, err)
111 }
112
113 item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal})
114@@ -469,22 +476,42 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
115
116 encodedItem, err := netlink.MarshalAttributes(item)
117 if err != nil {
118- return nil, fmt.Errorf("marshal item %d: %v", i, err)
119+ return fmt.Errorf("marshal item %d: %v", i, err)
120+ }
121+
122+ itemSize := unix.NLA_HDRLEN + len(encodedItem)
123+ if batchSize+itemSize > maxElemBatchSize {
124+ batches = append(batches, elements)
125+ elements = nil
126+ batchSize = 0
127 }
128 elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem})
129+ batchSize += itemSize
130 }
131+ batches = append(batches, elements)
132
133- encodedElem, err := netlink.MarshalAttributes(elements)
134- if err != nil {
135- return nil, fmt.Errorf("marshal elements: %v", err)
136- }
137+ for _, batch := range batches {
138+ encodedElem, err := netlink.MarshalAttributes(batch)
139+ if err != nil {
140+ return fmt.Errorf("marshal elements: %v", err)
141+ }
142
143- return []netlink.Attribute{
144- {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
145- {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)},
146- {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
147- {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
148- }, nil
149+ message := []netlink.Attribute{
150+ {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
151+ {Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)},
152+ {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
153+ {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
154+ }
155+
156+ cc.messages = append(cc.messages, netlinkMessage{
157+ Header: netlink.Header{
158+ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
159+ Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
160+ },
161+ Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(message)...),
162+ })
163+ }
164+ return nil
165 }
166
167 // AddSet adds the specified Set.
168@@ -659,22 +686,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
169 })
170
171 // Set the values of the set if initial values were provided.
172- if len(vals) > 0 {
173- hdrType := unix.NFT_MSG_NEWSETELEM
174- elements, err := s.makeElemList(vals, s.ID)
175- if err != nil {
176- return err
177- }
178- cc.messages = append(cc.messages, netlinkMessage{
179- Header: netlink.Header{
180- Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
181- Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
182- },
183- Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
184- })
185- }
186-
187- return nil
188+ return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
189 }
190
191 // DelSet deletes a specific set, along with all elements it contains.
192@@ -694,29 +706,6 @@ func (cc *Conn) DelSet(s *Set) {
193 })
194 }
195
196-// SetDeleteElements deletes data points from an nftables set.
197-func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
198- cc.mu.Lock()
199- defer cc.mu.Unlock()
200- if s.Anonymous {
201- return errors.New("anonymous sets cannot be updated")
202- }
203-
204- elements, err := s.makeElemList(vals, s.ID)
205- if err != nil {
206- return err
207- }
208- cc.messages = append(cc.messages, netlinkMessage{
209- Header: netlink.Header{
210- Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
211- Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
212- },
213- Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
214- })
215-
216- return nil
217-}
218-
219 // FlushSet deletes all data points from an nftables set.
220 func (cc *Conn) FlushSet(s *Set) {
221 cc.mu.Lock()
222@@ -972,8 +961,8 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
223 defer func() { _ = closer() }()
224
225 data, err := netlink.MarshalAttributes([]netlink.Attribute{
226- {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
227- {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
228+ {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
229+ {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
230 })
231 if err != nil {
232 return nil, err