Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions platform/firewall/nftables_firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ func (f *NftablesFirewall) SetupMonitFirewall() error {
// Add jump to jobs chain first (so job rules are checked before agent rules)
f.addJumpToJobsChain()

// Allow return traffic for established/related connections
f.addConntrackRule(f.monitChain)

// Add allow rule for root (UID 0)
f.addMonitAllowRule()

Expand Down Expand Up @@ -255,6 +258,9 @@ func (f *NftablesFirewall) SetupNATSFirewall(mbusURL string) error {
// Flush NATS chain (removes old rules for previous IPs)
f.conn.FlushChain(f.natsChain)

// Allow return traffic for established/related connections
f.addConntrackRule(f.natsChain)

// Add rules for each resolved IP
for _, addr := range addrs {
f.addNATSAllowRule(addr, port)
Expand Down Expand Up @@ -366,6 +372,14 @@ func (f *NftablesFirewall) ensureNATSChain() {
f.conn.AddChain(f.natsChain)
}

func (f *NftablesFirewall) addConntrackRule(chain *nftables.Chain) {
f.conn.AddRule(&nftables.Rule{
Table: f.table,
Chain: chain,
Exprs: buildConntrackEstablishedRelatedExprs(),
})
}

func (f *NftablesFirewall) addMonitAllowRule() {
// Rule: meta skuid 0 ip daddr 127.0.0.1 tcp dport 2822 accept
exprs := buildUIDMatchExprs(0)
Expand Down Expand Up @@ -528,6 +542,38 @@ func (f *NftablesFirewall) addCgroupRule(inodeID uint64, cgroupPath string) erro
return nil
}

// buildConntrackEstablishedRelatedExprs creates expressions for:
//
// ct state established,related accept
//
// This allows return traffic for already-established connections, preventing
// existing connections from being broken when firewall rules are reloaded.
func buildConntrackEstablishedRelatedExprs() []expr.Any {
ctStateMask := expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED
maskBytes := make([]byte, 4)
binary.NativeEndian.PutUint32(maskBytes, ctStateMask)

return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: maskBytes,
Xor: []byte{0, 0, 0, 0},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{Kind: expr.VerdictAccept},
}
}

// buildLoopbackDestExprs creates expressions for matching IPv4 loopback destination.
// Note: IPv6 loopback (::1) is intentionally not protected because monit only
// binds to 127.0.0.1:2822 (see jobsupervisor/monit/provider.go).
Expand Down
69 changes: 54 additions & 15 deletions platform/firewall/nftables_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var _ = Describe("NftablesFirewall", func() {
Expect(fakeConn.AddTableCallCount()).To(Equal(1))
Expect(fakeConn.AddChainCallCount()).To(Equal(2)) // jobs chain + monit chain
Expect(fakeConn.FlushChainCallCount()).To(Equal(1))
Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) // jump + allow + block
Expect(fakeConn.AddRuleCallCount()).To(Equal(4)) // jump + conntrack + allow + block
Expect(fakeConn.FlushCallCount()).To(Equal(1))
})

Expand Down Expand Up @@ -107,16 +107,34 @@ var _ = Describe("NftablesFirewall", func() {
Expect(verdict.Chain).To(Equal("monit_access_jobs"))
})

It("adds allow rule for UID 0 after jump rule", func() {
It("adds conntrack established,related rule after jump rule", func() {
err := manager.SetupMonitFirewall()
Expect(err).NotTo(HaveOccurred())

// Second rule should be the allow rule (has UID match expressions)
allowRule := fakeConn.AddRuleArgsForCall(1)
ctRule := fakeConn.AddRuleArgsForCall(1)
Expect(ctRule.Chain.Name).To(Equal("monit_access"))

hasCt := false
hasAccept := false
for _, e := range ctRule.Exprs {
if ct, ok := e.(*expr.Ct); ok && ct.Key == expr.CtKeySTATE {
hasCt = true
}
if verdict, ok := e.(*expr.Verdict); ok && verdict.Kind == expr.VerdictAccept {
hasAccept = true
}
}
Expect(hasCt).To(BeTrue(), "rule should load ct state")
Expect(hasAccept).To(BeTrue(), "rule should accept")
})

It("adds allow rule for UID 0 after conntrack rule", func() {
err := manager.SetupMonitFirewall()
Expect(err).NotTo(HaveOccurred())

allowRule := fakeConn.AddRuleArgsForCall(2)
Expect(allowRule.Chain.Name).To(Equal("monit_access"))
// The allow rule has more expressions (UID match + loopback + port + accept)
// Block rule has fewer (loopback + port + drop)
blockRule := fakeConn.AddRuleArgsForCall(2)
blockRule := fakeConn.AddRuleArgsForCall(3)
Expect(len(allowRule.Exprs)).To(BeNumerically(">", len(blockRule.Exprs)))
})

Expand Down Expand Up @@ -330,8 +348,8 @@ var _ = Describe("NftablesFirewall", func() {
Expect(fakeConn.AddTableCallCount()).To(Equal(1))
Expect(fakeConn.AddChainCallCount()).To(Equal(1))
Expect(fakeConn.FlushChainCallCount()).To(Equal(1))
// One allow rule + one block rule
Expect(fakeConn.AddRuleCallCount()).To(Equal(2))
// conntrack + one allow rule + one block rule
Expect(fakeConn.AddRuleCallCount()).To(Equal(3))
Expect(fakeConn.FlushCallCount()).To(Equal(1))
})

Expand All @@ -344,14 +362,35 @@ var _ = Describe("NftablesFirewall", func() {
Expect(chain.Type).To(Equal(nftables.ChainTypeFilter))
Expect(chain.Hooknum).To(Equal(nftables.ChainHookOutput))
})

It("adds conntrack established,related rule as first rule", func() {
err := manager.SetupNATSFirewall("nats://192.168.1.100:4222")
Expect(err).NotTo(HaveOccurred())

ctRule := fakeConn.AddRuleArgsForCall(0)
Expect(ctRule.Chain.Name).To(Equal("nats_access"))

hasCt := false
hasAccept := false
for _, e := range ctRule.Exprs {
if ct, ok := e.(*expr.Ct); ok && ct.Key == expr.CtKeySTATE {
hasCt = true
}
if verdict, ok := e.(*expr.Verdict); ok && verdict.Kind == expr.VerdictAccept {
hasAccept = true
}
}
Expect(hasCt).To(BeTrue(), "rule should load ct state")
Expect(hasAccept).To(BeTrue(), "rule should accept")
})
})

Context("with an IPv6 address URL", func() {
It("creates rules for the IPv6 address", func() {
err := manager.SetupNATSFirewall("nats://user:pass@[2001:db8::1]:4222")
Expect(err).NotTo(HaveOccurred())

Expect(fakeConn.AddRuleCallCount()).To(Equal(2))
Expect(fakeConn.AddRuleCallCount()).To(Equal(3))
Expect(fakeConn.FlushCallCount()).To(Equal(1))
})
})
Expand All @@ -369,8 +408,8 @@ var _ = Describe("NftablesFirewall", func() {
Expect(fakeResolver.LookupIPCallCount()).To(Equal(1))
Expect(fakeResolver.LookupIPArgsForCall(0)).To(Equal("nats.example.com"))

// Two IPs * 2 rules each = 4 rules
Expect(fakeConn.AddRuleCallCount()).To(Equal(4))
// 1 conntrack + two IPs * 2 rules each = 5 rules
Expect(fakeConn.AddRuleCallCount()).To(Equal(5))
})

It("handles DNS resolution failure gracefully", func() {
Expand All @@ -389,7 +428,7 @@ var _ = Describe("NftablesFirewall", func() {
err := manager.SetupNATSFirewall("nats://192.168.1.100")
Expect(err).NotTo(HaveOccurred())

Expect(fakeConn.AddRuleCallCount()).To(Equal(2))
Expect(fakeConn.AddRuleCallCount()).To(Equal(3))
})
})

Expand All @@ -398,7 +437,7 @@ var _ = Describe("NftablesFirewall", func() {
err := manager.SetupNATSFirewall("nats://192.168.1.100:5222")
Expect(err).NotTo(HaveOccurred())

Expect(fakeConn.AddRuleCallCount()).To(Equal(2))
Expect(fakeConn.AddRuleCallCount()).To(Equal(3))
})
})

Expand Down Expand Up @@ -458,7 +497,7 @@ var _ = Describe("NftablesFirewall", func() {

Expect(fakeConn.AddTableCallCount()).To(Equal(1))
Expect(fakeConn.AddChainCallCount()).To(Equal(1))
Expect(fakeConn.AddRuleCallCount()).To(Equal(2))
Expect(fakeConn.AddRuleCallCount()).To(Equal(3))
})

It("returns nil on success", func() {
Expand Down