From 472cb1806090017ecf9ddec23edd36a14bff88ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louren=C3=A7o=20Vales?= <133565059+lourencovales@users.noreply.github.com> Date: Sun, 5 Oct 2025 21:22:48 +0200 Subject: [PATCH] added more robust validation; small changes --- engine/resources/cloudflare_dns.go | 60 +++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/engine/resources/cloudflare_dns.go b/engine/resources/cloudflare_dns.go index 3a3a74ba..7792e46e 100644 --- a/engine/resources/cloudflare_dns.go +++ b/engine/resources/cloudflare_dns.go @@ -32,6 +32,7 @@ package resources import ( "context" "fmt" + "strings" "github.com/purpleidea/mgmt/engine" "github.com/purpleidea/mgmt/engine/traits" @@ -73,6 +74,10 @@ type CloudflareDNSRes struct { // Type (e.g., IP address for A records, hostname for CNAME records). Content string `lang:"content"` + // Data is a value that's specific for SRV records, containing the priority, + // weight, port, and SRV targets. + Data *dns.SRVRecordData `lang:"srv_data"` + // Priority is the priority value for records that support it (e.g., MX // records). This is a pointer to distinguish between an unset value and // a zero value. @@ -152,8 +157,12 @@ func (obj *CloudflareDNSRes) Validate() error { return fmt.Errorf("content is required when state is 'exists'") } + if obj.Type == "MX" && obj.Priority == nil { + return fmt.Errorf("priority is required for MX records") + } + if obj.MetaParams().Poll == 0 || obj.MetaParams().Poll < 1 { // CF accepts ~4req/s so this is good enough - return fmt.Errorf("cloudflare:dns requires polling, set Meta:poll param (e.g., 60 seconds), min. 1s") + return fmt.Errorf("cloudflare:dns requires polling, set Meta:poll param (e.g., 60 seconds), min. 1s") } return nil @@ -233,7 +242,8 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool, } } - // List existing records + // we're using `contains` so as to get the candidates, as `exact` might not + // give the expected results depending on how the user specified it. listParams := dns.RecordListParams{ ZoneID: cloudflare.F(obj.zoneID), Name: cloudflare.F(dns.RecordListParamsName{ @@ -247,10 +257,15 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool, return false, errwrap.Wrapf(err, "failed to list DNS records") } - recordExists := len(recordList.Result) > 0 + // here we filter to find the exact match + recordExists := false var record dns.RecordResponse - if recordExists { - record = recordList.Result[0] + for _, r := range recordList.Result { + if obj.matchesRecordName(r.Name) { + record = r + recordExists = true + break + } } switch obj.State { @@ -469,9 +484,10 @@ func (obj *CloudflareDNSRes) buildRecordParam() (any, error) { case "SRV": param := dns.SRVRecordParam{ - Name: cloudflare.F(obj.RecordName), - Type: cloudflare.F(dns.SRVRecordTypeSRV), - TTL: cloudflare.F(ttl), + Name: cloudflare.F(obj.RecordName), + Type: cloudflare.F(dns.SRVRecordTypeSRV), + Content: cloudflare.F(obj.Content), + TTL: cloudflare.F(ttl), } if obj.Proxied != nil { param.Proxied = cloudflare.F(*obj.Proxied) @@ -641,7 +657,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b } if cfRes.Zone == obj.Zone { - recordKey := fmt.Sprintf("%s:%s", cfRes.RecordName, cfRes.Type) + recordKey := fmt.Sprintf("%s:%s:%s", cfRes.RecordName, cfRes.Type, + cfRes.Content) excludes[recordKey] = true } } @@ -649,7 +666,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b checkOK := true for _, record := range allRecords { - recordKey := fmt.Sprintf("%s:%s", record.Name, record.Type) + recordKey := fmt.Sprintf("%s:%s:%s", record.Name, record.Type, + record.Content) if excludes[recordKey] { continue @@ -682,3 +700,25 @@ func (obj *CloudflareDNSRes) GraphQueryAllowed(opts ...engine.GraphQueryableOpti } return nil } + +// matchesRecordName checks if a record name from the API matches our desired record name. +// Handles both FQDN (www.example.com) and short form (www) comparisons. +func (obj *CloudflareDNSRes) matchesRecordName(apiRecordName string) bool { + desired := obj.normalizeRecordName(obj.RecordName) + actual := obj.normalizeRecordName(apiRecordName) + return desired == actual +} + +// normalizeRecordName converts a record name to a consistent format for comparison. +// Converts to FQDN format (e.g., "www" -> "www.example.com", "@" -> "example.com") +func (obj *CloudflareDNSRes) normalizeRecordName(name string) string { + if name == "@" || name == obj.Zone { + return obj.Zone + } + + if strings.HasSuffix(name, "."+obj.Zone) || name == obj.Zone { + return name + } + + return name + "." + obj.Zone +}