added more robust validation; small changes

This commit is contained in:
Lourenço Vales
2025-10-05 21:22:48 +02:00
parent bc4f7b7309
commit 1787e44503

View File

@@ -32,6 +32,7 @@ package resources
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/purpleidea/mgmt/engine" "github.com/purpleidea/mgmt/engine"
"github.com/purpleidea/mgmt/engine/traits" "github.com/purpleidea/mgmt/engine/traits"
@@ -73,6 +74,10 @@ type CloudflareDNSRes struct {
// Type (e.g., IP address for A records, hostname for CNAME records). // Type (e.g., IP address for A records, hostname for CNAME records).
Content string `lang:"content"` Content string `lang:"content"`
// Data is a value that's specific for SRV records, containing the priority,
// weight, port, and SRV targets.
Data *dns.SRVRecordData `lang:"srv_data"`
// Priority is the priority value for records that support it (e.g., MX // Priority is the priority value for records that support it (e.g., MX
// records). This is a pointer to distinguish between an unset value and // records). This is a pointer to distinguish between an unset value and
// a zero value. // a zero value.
@@ -152,8 +157,12 @@ func (obj *CloudflareDNSRes) Validate() error {
return fmt.Errorf("content is required when state is 'exists'") return fmt.Errorf("content is required when state is 'exists'")
} }
if obj.Type == "MX" && obj.Priority == nil {
return fmt.Errorf("priority is required for MX records")
}
if obj.MetaParams().Poll == 0 || obj.MetaParams().Poll < 1 { // CF accepts ~4req/s so this is good enough if obj.MetaParams().Poll == 0 || obj.MetaParams().Poll < 1 { // CF accepts ~4req/s so this is good enough
return fmt.Errorf("cloudflare:dns requires polling, set Meta:poll param (e.g., 60 seconds), min. 1s") return fmt.Errorf("cloudflare:dns requires polling, set Meta:poll param (e.g., 60 seconds), min. 1s")
} }
return nil return nil
@@ -233,7 +242,8 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool,
} }
} }
// List existing records // we're using `contains` so as to get the candidates, as `exact` might not
// give the expected results depending on how the user specified it.
listParams := dns.RecordListParams{ listParams := dns.RecordListParams{
ZoneID: cloudflare.F(obj.zoneID), ZoneID: cloudflare.F(obj.zoneID),
Name: cloudflare.F(dns.RecordListParamsName{ Name: cloudflare.F(dns.RecordListParamsName{
@@ -247,10 +257,15 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool,
return false, errwrap.Wrapf(err, "failed to list DNS records") return false, errwrap.Wrapf(err, "failed to list DNS records")
} }
recordExists := len(recordList.Result) > 0 // here we filter to find the exact match
recordExists := false
var record dns.RecordResponse var record dns.RecordResponse
if recordExists { for _, r := range recordList.Result {
record = recordList.Result[0] if obj.matchesRecordName(r.Name) {
record = r
recordExists = true
break
}
} }
switch obj.State { switch obj.State {
@@ -469,9 +484,10 @@ func (obj *CloudflareDNSRes) buildRecordParam() (any, error) {
case "SRV": case "SRV":
param := dns.SRVRecordParam{ param := dns.SRVRecordParam{
Name: cloudflare.F(obj.RecordName), Name: cloudflare.F(obj.RecordName),
Type: cloudflare.F(dns.SRVRecordTypeSRV), Type: cloudflare.F(dns.SRVRecordTypeSRV),
TTL: cloudflare.F(ttl), Content: cloudflare.F(obj.Content),
TTL: cloudflare.F(ttl),
} }
if obj.Proxied != nil { if obj.Proxied != nil {
param.Proxied = cloudflare.F(*obj.Proxied) param.Proxied = cloudflare.F(*obj.Proxied)
@@ -641,7 +657,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b
} }
if cfRes.Zone == obj.Zone { if cfRes.Zone == obj.Zone {
recordKey := fmt.Sprintf("%s:%s", cfRes.RecordName, cfRes.Type) recordKey := fmt.Sprintf("%s:%s:%s", cfRes.RecordName, cfRes.Type,
cfRes.Content)
excludes[recordKey] = true excludes[recordKey] = true
} }
} }
@@ -649,7 +666,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b
checkOK := true checkOK := true
for _, record := range allRecords { for _, record := range allRecords {
recordKey := fmt.Sprintf("%s:%s", record.Name, record.Type) recordKey := fmt.Sprintf("%s:%s:%s", record.Name, record.Type,
record.Content)
if excludes[recordKey] { if excludes[recordKey] {
continue continue
@@ -682,3 +700,25 @@ func (obj *CloudflareDNSRes) GraphQueryAllowed(opts ...engine.GraphQueryableOpti
} }
return nil return nil
} }
// matchesRecordName checks if a record name from the API matches our desired record name.
// Handles both FQDN (www.example.com) and short form (www) comparisons.
func (obj *CloudflareDNSRes) matchesRecordName(apiRecordName string) bool {
desired := obj.normalizeRecordName(obj.RecordName)
actual := obj.normalizeRecordName(apiRecordName)
return desired == actual
}
// normalizeRecordName converts a record name to a consistent format for comparison.
// Converts to FQDN format (e.g., "www" -> "www.example.com", "@" -> "example.com")
func (obj *CloudflareDNSRes) normalizeRecordName(name string) string {
if name == "@" || name == obj.Zone {
return obj.Zone
}
if strings.HasSuffix(name, "."+obj.Zone) || name == obj.Zone {
return name
}
return name + "." + obj.Zone
}