2 Commits

Author SHA1 Message Date
Lourenço Vales
1787e44503 added more robust validation; small changes 2025-10-05 21:22:48 +02:00
Lourenço Vales
bc4f7b7309 fixed slight error in validation 2025-10-05 20:59:27 +02:00

View File

@@ -32,6 +32,7 @@ package resources
import (
"context"
"fmt"
"strings"
"github.com/purpleidea/mgmt/engine"
"github.com/purpleidea/mgmt/engine/traits"
@@ -73,6 +74,10 @@ type CloudflareDNSRes struct {
// Type (e.g., IP address for A records, hostname for CNAME records).
Content string `lang:"content"`
// Data is a value that's specific for SRV records, containing the priority,
// weight, port, and SRV targets.
Data *dns.SRVRecordData `lang:"srv_data"`
// Priority is the priority value for records that support it (e.g., MX
// records). This is a pointer to distinguish between an unset value and
// a zero value.
@@ -152,6 +157,10 @@ func (obj *CloudflareDNSRes) Validate() error {
return fmt.Errorf("content is required when state is 'exists'")
}
if obj.Type == "MX" && obj.Priority == nil {
return fmt.Errorf("priority is required for MX records")
}
if obj.MetaParams().Poll == 0 || obj.MetaParams().Poll < 1 { // CF accepts ~4req/s so this is good enough
return fmt.Errorf("cloudflare:dns requires polling, set Meta:poll param (e.g., 60 seconds), min. 1s")
}
@@ -233,7 +242,8 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool,
}
}
// List existing records
// we're using `contains` so as to get the candidates, as `exact` might not
// give the expected results depending on how the user specified it.
listParams := dns.RecordListParams{
ZoneID: cloudflare.F(obj.zoneID),
Name: cloudflare.F(dns.RecordListParamsName{
@@ -247,10 +257,15 @@ func (obj *CloudflareDNSRes) CheckApply(ctx context.Context, apply bool) (bool,
return false, errwrap.Wrapf(err, "failed to list DNS records")
}
recordExists := len(recordList.Result) > 0
// here we filter to find the exact match
recordExists := false
var record dns.RecordResponse
if recordExists {
record = recordList.Result[0]
for _, r := range recordList.Result {
if obj.matchesRecordName(r.Name) {
record = r
recordExists = true
break
}
}
switch obj.State {
@@ -319,7 +334,11 @@ func (obj *CloudflareDNSRes) Cmp(r engine.Res) error {
return fmt.Errorf("apitoken differs")
}
if obj.Proxied != res.Proxied {
if (obj.Proxied == nil) != (res.Proxied == nil) {
return fmt.Errorf("proxied values differ")
}
if obj.Proxied != nil && *obj.Proxied != *res.Proxied {
return fmt.Errorf("proxied values differ")
}
@@ -467,6 +486,7 @@ func (obj *CloudflareDNSRes) buildRecordParam() (any, error) {
param := dns.SRVRecordParam{
Name: cloudflare.F(obj.RecordName),
Type: cloudflare.F(dns.SRVRecordTypeSRV),
Content: cloudflare.F(obj.Content),
TTL: cloudflare.F(ttl),
}
if obj.Proxied != nil {
@@ -637,7 +657,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b
}
if cfRes.Zone == obj.Zone {
recordKey := fmt.Sprintf("%s:%s", cfRes.RecordName, cfRes.Type)
recordKey := fmt.Sprintf("%s:%s:%s", cfRes.RecordName, cfRes.Type,
cfRes.Content)
excludes[recordKey] = true
}
}
@@ -645,7 +666,8 @@ func (obj *CloudflareDNSRes) purgeCheckApply(ctx context.Context, apply bool) (b
checkOK := true
for _, record := range allRecords {
recordKey := fmt.Sprintf("%s:%s", record.Name, record.Type)
recordKey := fmt.Sprintf("%s:%s:%s", record.Name, record.Type,
record.Content)
if excludes[recordKey] {
continue
@@ -678,3 +700,25 @@ func (obj *CloudflareDNSRes) GraphQueryAllowed(opts ...engine.GraphQueryableOpti
}
return nil
}
// matchesRecordName checks if a record name from the API matches our desired record name.
// Handles both FQDN (www.example.com) and short form (www) comparisons.
func (obj *CloudflareDNSRes) matchesRecordName(apiRecordName string) bool {
desired := obj.normalizeRecordName(obj.RecordName)
actual := obj.normalizeRecordName(apiRecordName)
return desired == actual
}
// normalizeRecordName converts a record name to a consistent format for comparison.
// Converts to FQDN format (e.g., "www" -> "www.example.com", "@" -> "example.com")
func (obj *CloudflareDNSRes) normalizeRecordName(name string) string {
if name == "@" || name == obj.Zone {
return obj.Zone
}
if strings.HasSuffix(name, "."+obj.Zone) || name == obj.Zone {
return name
}
return name + "." + obj.Zone
}