added more robust validation; small changes

This commit is contained in:
Lourenço Vales
2025-10-05 21:22:48 +02:00
parent bc4f7b7309
commit 1787e44503

View File

@@ -32,6 +32,7 @@ package resources
import (
"context"
"fmt"
"strings"
"github.com/purpleidea/mgmt/engine"
"github.com/purpleidea/mgmt/engine/traits"
@@ -73,6 +74,10 @@ type CloudflareDNSRes struct {
// Type (e.g., IP address for A records, hostname for CNAME records).
Content string `lang:"content"`
// Data is a value that's specific for SRV records, containing the priority,
// weight, port, and SRV targets.
Data *dns.SRVRecordData `lang:"srv_data"`
// Priority is the priority value for records that support it (e.g., MX
// records). This is a pointer to distinguish between an unset value and
// a zero value.
@@ -152,8 +157,12 @@ func (obj *CloudflareDNSRes) Validate() error {
return fmt.Errorf("content is required when state is 'exists'")
}
if obj.Type == "MX" && obj.Priority == nil {
return fmt.Errorf("priority is required for MX records")
}
if obj.MetaParams().Poll == 0 || obj.MetaParams().Poll < 1 { // CF accepts ~4req/s so this is good enough
return fmt.Errorf("cloudflare:dns requires polling, set Meta:poll param (e.g., 60 seconds), min. 1s")
return fmt.Errorf("cloudflare:dns requires polling, set Meta:poll param (e.g., 60 seconds), min. 1s")
}
return nil
@@ -233,7 +242,8 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool,
}
}
// List existing records
// we're using `contains` so as to get the candidates, as `exact` might not
// give the expected results depending on how the user specified it.
listParams := dns.RecordListParams{
ZoneID: cloudflare.F(obj.zoneID),
Name: cloudflare.F(dns.RecordListParamsName{
@@ -247,10 +257,15 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool,
return false, errwrap.Wrapf(err, "failed to list DNS records")
}
recordExists := len(recordList.Result) > 0
// here we filter to find the exact match
recordExists := false
var record dns.RecordResponse
if recordExists {
record = recordList.Result[0]
for _, r := range recordList.Result {
if obj.matchesRecordName(r.Name) {
record = r
recordExists = true
break
}
}
switch obj.State {
@@ -469,9 +484,10 @@ func (obj *CloudflareDNSRes) buildRecordParam() (any, error) {
case "SRV":
param := dns.SRVRecordParam{
Name: cloudflare.F(obj.RecordName),
Type: cloudflare.F(dns.SRVRecordTypeSRV),
TTL: cloudflare.F(ttl),
Name: cloudflare.F(obj.RecordName),
Type: cloudflare.F(dns.SRVRecordTypeSRV),
Content: cloudflare.F(obj.Content),
TTL: cloudflare.F(ttl),
}
if obj.Proxied != nil {
param.Proxied = cloudflare.F(*obj.Proxied)
@@ -641,7 +657,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b
}
if cfRes.Zone == obj.Zone {
recordKey := fmt.Sprintf("%s:%s", cfRes.RecordName, cfRes.Type)
recordKey := fmt.Sprintf("%s:%s:%s", cfRes.RecordName, cfRes.Type,
cfRes.Content)
excludes[recordKey] = true
}
}
@@ -649,7 +666,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b
checkOK := true
for _, record := range allRecords {
recordKey := fmt.Sprintf("%s:%s", record.Name, record.Type)
recordKey := fmt.Sprintf("%s:%s:%s", record.Name, record.Type,
record.Content)
if excludes[recordKey] {
continue
@@ -682,3 +700,25 @@ func (obj *CloudflareDNSRes) GraphQueryAllowed(opts ...engine.GraphQueryableOpti
}
return nil
}
// matchesRecordName checks if a record name from the API matches our desired record name.
// Handles both FQDN (www.example.com) and short form (www) comparisons.
func (obj *CloudflareDNSRes) matchesRecordName(apiRecordName string) bool {
desired := obj.normalizeRecordName(obj.RecordName)
actual := obj.normalizeRecordName(apiRecordName)
return desired == actual
}
// normalizeRecordName converts a record name to a consistent format for comparison.
// Converts to FQDN format (e.g., "www" -> "www.example.com", "@" -> "example.com")
func (obj *CloudflareDNSRes) normalizeRecordName(name string) string {
if name == "@" || name == obj.Zone {
return obj.Zone
}
if strings.HasSuffix(name, "."+obj.Zone) || name == obj.Zone {
return name
}
return name + "." + obj.Zone
}