diff --git a/platform/firewall/nftables_firewall.go b/platform/firewall/nftables_firewall.go index 5a1b2bc4f..795808c52 100644 --- a/platform/firewall/nftables_firewall.go +++ b/platform/firewall/nftables_firewall.go @@ -164,6 +164,9 @@ func (f *NftablesFirewall) SetupMonitFirewall() error { // Add jump to jobs chain first (so job rules are checked before agent rules) f.addJumpToJobsChain() + // Allow return traffic for established/related connections + f.addConntrackRule(f.monitChain) + // Add allow rule for root (UID 0) f.addMonitAllowRule() @@ -255,6 +258,9 @@ func (f *NftablesFirewall) SetupNATSFirewall(mbusURL string) error { // Flush NATS chain (removes old rules for previous IPs) f.conn.FlushChain(f.natsChain) + // Allow return traffic for established/related connections + f.addConntrackRule(f.natsChain) + // Add rules for each resolved IP for _, addr := range addrs { f.addNATSAllowRule(addr, port) @@ -366,6 +372,14 @@ func (f *NftablesFirewall) ensureNATSChain() { f.conn.AddChain(f.natsChain) } +func (f *NftablesFirewall) addConntrackRule(chain *nftables.Chain) { + f.conn.AddRule(&nftables.Rule{ + Table: f.table, + Chain: chain, + Exprs: buildConntrackEstablishedRelatedExprs(), + }) +} + func (f *NftablesFirewall) addMonitAllowRule() { // Rule: meta skuid 0 ip daddr 127.0.0.1 tcp dport 2822 accept exprs := buildUIDMatchExprs(0) @@ -528,6 +542,38 @@ func (f *NftablesFirewall) addCgroupRule(inodeID uint64, cgroupPath string) erro return nil } +// buildConntrackEstablishedRelatedExprs creates expressions for: +// +// ct state established,related accept +// +// This allows return traffic for already-established connections, preventing +// existing connections from being broken when firewall rules are reloaded. +func buildConntrackEstablishedRelatedExprs() []expr.Any { + ctStateMask := expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED + maskBytes := make([]byte, 4) + binary.NativeEndian.PutUint32(maskBytes, ctStateMask) + + return []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: maskBytes, + Xor: []byte{0, 0, 0, 0}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Verdict{Kind: expr.VerdictAccept}, + } +} + // buildLoopbackDestExprs creates expressions for matching IPv4 loopback destination. // Note: IPv6 loopback (::1) is intentionally not protected because monit only // binds to 127.0.0.1:2822 (see jobsupervisor/monit/provider.go). diff --git a/platform/firewall/nftables_firewall_test.go b/platform/firewall/nftables_firewall_test.go index 02c7ec49a..d9aa39582 100644 --- a/platform/firewall/nftables_firewall_test.go +++ b/platform/firewall/nftables_firewall_test.go @@ -55,7 +55,7 @@ var _ = Describe("NftablesFirewall", func() { Expect(fakeConn.AddTableCallCount()).To(Equal(1)) Expect(fakeConn.AddChainCallCount()).To(Equal(2)) // jobs chain + monit chain Expect(fakeConn.FlushChainCallCount()).To(Equal(1)) - Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) // jump + allow + block + Expect(fakeConn.AddRuleCallCount()).To(Equal(4)) // jump + conntrack + allow + block Expect(fakeConn.FlushCallCount()).To(Equal(1)) }) @@ -107,16 +107,34 @@ var _ = Describe("NftablesFirewall", func() { Expect(verdict.Chain).To(Equal("monit_access_jobs")) }) - It("adds allow rule for UID 0 after jump rule", func() { + It("adds conntrack established,related rule after jump rule", func() { err := manager.SetupMonitFirewall() Expect(err).NotTo(HaveOccurred()) - // Second rule should be the allow rule (has UID match expressions) - allowRule := fakeConn.AddRuleArgsForCall(1) + ctRule := fakeConn.AddRuleArgsForCall(1) + Expect(ctRule.Chain.Name).To(Equal("monit_access")) + + hasCt := false + hasAccept := false + for _, e := range ctRule.Exprs { + if ct, ok := e.(*expr.Ct); ok && ct.Key == expr.CtKeySTATE { + hasCt = true + } + if verdict, ok := e.(*expr.Verdict); ok && verdict.Kind == expr.VerdictAccept { + hasAccept = true + } + } + Expect(hasCt).To(BeTrue(), "rule should load ct state") + Expect(hasAccept).To(BeTrue(), "rule should accept") + }) + + It("adds allow rule for UID 0 after conntrack rule", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + allowRule := fakeConn.AddRuleArgsForCall(2) Expect(allowRule.Chain.Name).To(Equal("monit_access")) - // The allow rule has more expressions (UID match + loopback + port + accept) - // Block rule has fewer (loopback + port + drop) - blockRule := fakeConn.AddRuleArgsForCall(2) + blockRule := fakeConn.AddRuleArgsForCall(3) Expect(len(allowRule.Exprs)).To(BeNumerically(">", len(blockRule.Exprs))) }) @@ -330,8 +348,8 @@ var _ = Describe("NftablesFirewall", func() { Expect(fakeConn.AddTableCallCount()).To(Equal(1)) Expect(fakeConn.AddChainCallCount()).To(Equal(1)) Expect(fakeConn.FlushChainCallCount()).To(Equal(1)) - // One allow rule + one block rule - Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + // conntrack + one allow rule + one block rule + Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) Expect(fakeConn.FlushCallCount()).To(Equal(1)) }) @@ -344,6 +362,27 @@ var _ = Describe("NftablesFirewall", func() { Expect(chain.Type).To(Equal(nftables.ChainTypeFilter)) Expect(chain.Hooknum).To(Equal(nftables.ChainHookOutput)) }) + + It("adds conntrack established,related rule as first rule", func() { + err := manager.SetupNATSFirewall("nats://192.168.1.100:4222") + Expect(err).NotTo(HaveOccurred()) + + ctRule := fakeConn.AddRuleArgsForCall(0) + Expect(ctRule.Chain.Name).To(Equal("nats_access")) + + hasCt := false + hasAccept := false + for _, e := range ctRule.Exprs { + if ct, ok := e.(*expr.Ct); ok && ct.Key == expr.CtKeySTATE { + hasCt = true + } + if verdict, ok := e.(*expr.Verdict); ok && verdict.Kind == expr.VerdictAccept { + hasAccept = true + } + } + Expect(hasCt).To(BeTrue(), "rule should load ct state") + Expect(hasAccept).To(BeTrue(), "rule should accept") + }) }) Context("with an IPv6 address URL", func() { @@ -351,7 +390,7 @@ var _ = Describe("NftablesFirewall", func() { err := manager.SetupNATSFirewall("nats://user:pass@[2001:db8::1]:4222") Expect(err).NotTo(HaveOccurred()) - Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) Expect(fakeConn.FlushCallCount()).To(Equal(1)) }) }) @@ -369,8 +408,8 @@ var _ = Describe("NftablesFirewall", func() { Expect(fakeResolver.LookupIPCallCount()).To(Equal(1)) Expect(fakeResolver.LookupIPArgsForCall(0)).To(Equal("nats.example.com")) - // Two IPs * 2 rules each = 4 rules - Expect(fakeConn.AddRuleCallCount()).To(Equal(4)) + // 1 conntrack + two IPs * 2 rules each = 5 rules + Expect(fakeConn.AddRuleCallCount()).To(Equal(5)) }) It("handles DNS resolution failure gracefully", func() { @@ -389,7 +428,7 @@ var _ = Describe("NftablesFirewall", func() { err := manager.SetupNATSFirewall("nats://192.168.1.100") Expect(err).NotTo(HaveOccurred()) - Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) }) }) @@ -398,7 +437,7 @@ var _ = Describe("NftablesFirewall", func() { err := manager.SetupNATSFirewall("nats://192.168.1.100:5222") Expect(err).NotTo(HaveOccurred()) - Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) }) }) @@ -458,7 +497,7 @@ var _ = Describe("NftablesFirewall", func() { Expect(fakeConn.AddTableCallCount()).To(Equal(1)) Expect(fakeConn.AddChainCallCount()).To(Equal(1)) - Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) }) It("returns nil on success", func() {