This ensures that docstring comments are wrapped to 80 chars. ffrank seemed to be making this mistake far too often, and it's a silly thing to look for manually. As it turns out, I've made it too, as have many others. Now we have a test that checks for most cases. There are still a few stray cases that aren't checked automatically, but this can be improved upon if someone is motivated to do so. Before anyone complains about the 80 character limit: this only checks docstring comments, not source code length or inline source code comments. There's no excuse for having docstrings that are badly reflowed or over 80 chars, particularly if you have an automated test.
455 lines
15 KiB
Go
455 lines
15 KiB
Go
// Mgmt
|
|
// Copyright (C) 2013-2020+ James Shubin and the project contributors
|
|
// Written by James Shubin <james@shubin.ca> and the project contributors
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
package util
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/gob"
|
|
"fmt"
|
|
"os"
|
|
"os/user"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/purpleidea/mgmt/engine"
|
|
"github.com/purpleidea/mgmt/lang/types"
|
|
"github.com/purpleidea/mgmt/util/errwrap"
|
|
|
|
"github.com/godbus/dbus"
|
|
)
|
|
|
|
const (
|
|
// StructTag is the key we use in struct field names for key mapping.
|
|
StructTag = "lang"
|
|
// DBusInterface is the dbus interface that contains genereal methods.
|
|
DBusInterface = "org.freedesktop.DBus"
|
|
// DBusAddMatch is the dbus method to receive a subset of dbus broadcast
|
|
// signals.
|
|
DBusAddMatch = DBusInterface + ".AddMatch"
|
|
// DBusRemoveMatch is the dbus method to remove a previously defined
|
|
// AddMatch rule.
|
|
DBusRemoveMatch = DBusInterface + ".RemoveMatch"
|
|
// DBusSystemd1Path is the base systemd1 path.
|
|
DBusSystemd1Path = "/org/freedesktop/systemd1"
|
|
// DBusSystemd1Iface is the base systemd1 interface.
|
|
DBusSystemd1Iface = "org.freedesktop.systemd1"
|
|
// DBusSystemd1ManagerIface is the systemd manager interface used for
|
|
// interfacing with systemd units.
|
|
DBusSystemd1ManagerIface = DBusSystemd1Iface + ".Manager"
|
|
// DBusRestartUnit is the dbus method for restarting systemd units.
|
|
DBusRestartUnit = DBusSystemd1ManagerIface + ".RestartUnit"
|
|
// DBusStopUnit is the dbus method for stopping systemd units.
|
|
DBusStopUnit = DBusSystemd1ManagerIface + ".StopUnit"
|
|
// DBusSignalJobRemoved is the name of the dbus signal that produces a
|
|
// message when a dbus job is done (or has errored.)
|
|
DBusSignalJobRemoved = "JobRemoved"
|
|
)
|
|
|
|
// ResPathUID returns a unique resource UID based on its name and kind. It's
|
|
// safe to use as a token in a path, and as a result has no slashes in it.
|
|
func ResPathUID(res engine.Res) string {
|
|
// res.Name() is NOT sufficiently unique to use as a UID here, because:
|
|
// a name of: /tmp/mgmt/foo is /tmp-mgmt-foo and
|
|
// a name of: /tmp/mgmt-foo -> /tmp-mgmt-foo if we replace slashes.
|
|
// As a result, we base64 encode (but without slashes).
|
|
name := strings.Replace(res.Name(), "/", "-", -1) // TODO: use ReplaceAll in 1.12
|
|
if os.PathSeparator != '/' { // lol windows?
|
|
name = strings.Replace(name, string(os.PathSeparator), "-", -1) // TODO: use ReplaceAll in 1.12
|
|
}
|
|
b := []byte(res.Name())
|
|
encoded := base64.URLEncoding.EncodeToString(b)
|
|
// Add the safe name on so that it's easier to identify by name...
|
|
return fmt.Sprintf("%s-%s+%s", res.Kind(), encoded, name)
|
|
}
|
|
|
|
// ResToB64 encodes a resource to a base64 encoded string (after serialization).
|
|
func ResToB64(res engine.Res) (string, error) {
|
|
b := bytes.Buffer{}
|
|
e := gob.NewEncoder(&b)
|
|
err := e.Encode(&res) // pass with &
|
|
if err != nil {
|
|
return "", errwrap.Wrapf(err, "gob failed to encode")
|
|
}
|
|
return base64.StdEncoding.EncodeToString(b.Bytes()), nil
|
|
}
|
|
|
|
// B64ToRes decodes a resource from a base64 encoded string (after
|
|
// deserialization).
|
|
func B64ToRes(str string) (engine.Res, error) {
|
|
var output interface{}
|
|
bb, err := base64.StdEncoding.DecodeString(str)
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf(err, "base64 failed to decode")
|
|
}
|
|
b := bytes.NewBuffer(bb)
|
|
d := gob.NewDecoder(b)
|
|
if err := d.Decode(&output); err != nil { // pass with &
|
|
return nil, errwrap.Wrapf(err, "gob failed to decode")
|
|
}
|
|
res, ok := output.(engine.Res)
|
|
if !ok {
|
|
return nil, fmt.Errorf("output `%v` is not a Res", output)
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// StructTagToFieldName returns a mapping from recommended alias to actual field
|
|
// name. It returns an error if it finds a collision. It uses the `lang` tags.
|
|
// It must be passed a ptr to a struct or it will error.
|
|
func StructTagToFieldName(stptr interface{}) (map[string]string, error) {
|
|
// TODO: fallback to looking up yaml tags, although harder to parse
|
|
result := make(map[string]string) // `lang` field tag -> field name
|
|
if stptr == nil {
|
|
return nil, fmt.Errorf("got nil input instead of ptr to struct")
|
|
}
|
|
typ := reflect.TypeOf(stptr)
|
|
if k := typ.Kind(); k != reflect.Ptr { // we only look at *Struct's
|
|
return nil, fmt.Errorf("input is not a ptr, got: %+v", k)
|
|
}
|
|
st := typ.Elem() // elem for ptr to struct (dereference the pointer)
|
|
if k := st.Kind(); k != reflect.Struct { // this should be a struct now
|
|
return nil, fmt.Errorf("input doesn't point to a struct, got: %+v", k)
|
|
}
|
|
|
|
for i := 0; i < st.NumField(); i++ {
|
|
field := st.Field(i)
|
|
name := field.Name
|
|
// if !ok, then nothing is found
|
|
if alias, ok := field.Tag.Lookup(StructTag); ok { // golang 1.7+
|
|
if val, exists := result[alias]; exists {
|
|
return nil, fmt.Errorf("field `%s` uses the same key `%s` as field `%s`", name, alias, val)
|
|
}
|
|
// empty string ("") is a valid value
|
|
if alias != "" {
|
|
result[alias] = name
|
|
}
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// StructFieldCompat returns whether a send struct and key is compatible with a
|
|
// recv struct and key. This inputs must both be a ptr to a string, and a valid
|
|
// key that can be found in the struct tag.
|
|
// TODO: add a bool to decide if *string to string or string to *string is okay.
|
|
func StructFieldCompat(st1 interface{}, key1 string, st2 interface{}, key2 string) error {
|
|
m1, err := StructTagToFieldName(st1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
k1, exists := m1[key1]
|
|
if !exists {
|
|
return fmt.Errorf("key not found in send struct")
|
|
}
|
|
|
|
m2, err := StructTagToFieldName(st2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
k2, exists := m2[key2]
|
|
if !exists {
|
|
return fmt.Errorf("key not found in recv struct")
|
|
}
|
|
|
|
obj1 := reflect.Indirect(reflect.ValueOf(st1))
|
|
//type1 := obj1.Type()
|
|
value1 := obj1.FieldByName(k1)
|
|
kind1 := value1.Kind()
|
|
|
|
obj2 := reflect.Indirect(reflect.ValueOf(st2))
|
|
//type2 := obj2.Type()
|
|
value2 := obj2.FieldByName(k2)
|
|
kind2 := value2.Kind()
|
|
|
|
if kind1 != kind2 {
|
|
return fmt.Errorf("kind mismatch between %s and %s", kind1, kind2)
|
|
}
|
|
|
|
if t1, t2 := value1.Type(), value2.Type(); t1 != t2 {
|
|
return fmt.Errorf("type mismatch between %s and %s", t1, t2)
|
|
}
|
|
|
|
if !value2.CanSet() { // if we can't set, then this is pointless!
|
|
return fmt.Errorf("can't set")
|
|
}
|
|
|
|
// if we can't interface, we can't compare...
|
|
if !value1.CanInterface() {
|
|
return fmt.Errorf("can't interface the send")
|
|
}
|
|
if !value2.CanInterface() {
|
|
return fmt.Errorf("can't interface the recv")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LowerStructFieldNameToFieldName returns a mapping from the lower case version
|
|
// of each field name to the actual field name. It only returns public fields.
|
|
// It returns an error if it finds a collision.
|
|
func LowerStructFieldNameToFieldName(res engine.Res) (map[string]string, error) {
|
|
result := make(map[string]string) // lower field name -> field name
|
|
st := reflect.TypeOf(res).Elem() // elem for ptr to res
|
|
for i := 0; i < st.NumField(); i++ {
|
|
field := st.Field(i)
|
|
name := field.Name
|
|
|
|
if strings.Title(name) != name { // must have been a priv field
|
|
continue
|
|
}
|
|
|
|
if alias := strings.ToLower(name); alias != "" {
|
|
if val, exists := result[alias]; exists {
|
|
return nil, fmt.Errorf("field `%s` uses the same key `%s` as field `%s`", name, alias, val)
|
|
}
|
|
result[alias] = name
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// LangFieldNameToStructFieldName returns the mapping from lang (AST) field
|
|
// names to field name as used in the struct. The logic here is a bit strange;
|
|
// if the resource has struct tags, then it uses those, otherwise it falls back
|
|
// to using the lower case versions of things. It might be clever to combine the
|
|
// two so that tagged fields are used as such, and others are used in lowercase,
|
|
// but this is currently not implemented.
|
|
// TODO: should this behaviour be changed?
|
|
func LangFieldNameToStructFieldName(kind string) (map[string]string, error) {
|
|
res, err := engine.NewResource(kind)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mapping, err := StructTagToFieldName(res)
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf(err, "resource kind `%s` has bad field mapping", kind)
|
|
}
|
|
if len(mapping) == 0 { // if no `lang` tags exist, get them automatically
|
|
mapping, err = LowerStructFieldNameToFieldName(res)
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf(err, "resource kind `%s` has bad automatic field mapping", kind)
|
|
}
|
|
}
|
|
|
|
return mapping, nil // lang field name -> field name
|
|
}
|
|
|
|
// StructKindToFieldNameTypeMap returns a map from field name to expected type
|
|
// in the lang type system.
|
|
func StructKindToFieldNameTypeMap(kind string) (map[string]*types.Type, error) {
|
|
res, err := engine.NewResource(kind)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sv := reflect.ValueOf(res).Elem() // pointer to struct, then struct
|
|
if k := sv.Kind(); k != reflect.Struct {
|
|
return nil, fmt.Errorf("expected struct, got: %s", k)
|
|
}
|
|
|
|
result := make(map[string]*types.Type)
|
|
|
|
st := reflect.TypeOf(res).Elem() // pointer to struct, then struct
|
|
for i := 0; i < st.NumField(); i++ {
|
|
field := st.Field(i)
|
|
name := field.Name
|
|
// TODO: in future, skip over fields that don't have a `lang` tag
|
|
//if name == "Base" { // TODO: hack!!!
|
|
// continue
|
|
//}
|
|
|
|
typ, err := types.TypeOf(field.Type)
|
|
// some types (eg complex64) aren't convertible, so skip for now...
|
|
if err != nil {
|
|
continue
|
|
//return nil, errwrap.Wrapf(err, "could not identify type of field `%s`", name)
|
|
}
|
|
result[name] = typ
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// LangFieldNameToStructType returns the mapping from lang (AST) field names,
|
|
// and the expected type in our type system for each.
|
|
func LangFieldNameToStructType(kind string) (map[string]*types.Type, error) {
|
|
// returns a mapping between fieldName and expected *types.Type
|
|
fieldNameTypMap, err := StructKindToFieldNameTypeMap(kind)
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf(err, "could not determine types for `%s` resource", kind)
|
|
}
|
|
|
|
mapping, err := LangFieldNameToStructFieldName(kind)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// transform from field name to tag name
|
|
typMap := make(map[string]*types.Type)
|
|
for name, typ := range fieldNameTypMap {
|
|
if strings.Title(name) != name {
|
|
continue // skip private fields
|
|
}
|
|
|
|
found := false
|
|
for k, v := range mapping {
|
|
if v != name {
|
|
continue
|
|
}
|
|
// found
|
|
if found { // previously found!
|
|
return nil, fmt.Errorf("duplicate mapping for: %s", name)
|
|
}
|
|
typMap[k] = typ
|
|
found = true // :)
|
|
}
|
|
if !found {
|
|
return nil, fmt.Errorf("could not find mapping for: %s", name)
|
|
}
|
|
}
|
|
|
|
return typMap, nil
|
|
}
|
|
|
|
// GetUID returns the UID of an user. It supports an UID or an username. Caller
|
|
// should first check user is not empty. It will return an error if it can't
|
|
// lookup the UID or username.
|
|
func GetUID(username string) (int, error) {
|
|
userObj, err := user.LookupId(username)
|
|
if err == nil {
|
|
return strconv.Atoi(userObj.Uid)
|
|
}
|
|
|
|
userObj, err = user.Lookup(username)
|
|
if err == nil {
|
|
return strconv.Atoi(userObj.Uid)
|
|
}
|
|
|
|
return -1, errwrap.Wrapf(err, "user lookup error (%s)", username)
|
|
}
|
|
|
|
// GetGID returns the GID of a group. It supports a GID or a group name. Caller
|
|
// should first check group is not empty. It will return an error if it can't
|
|
// lookup the GID or group name.
|
|
func GetGID(group string) (int, error) {
|
|
groupObj, err := user.LookupGroupId(group)
|
|
if err == nil {
|
|
return strconv.Atoi(groupObj.Gid)
|
|
}
|
|
|
|
groupObj, err = user.LookupGroup(group)
|
|
if err == nil {
|
|
return strconv.Atoi(groupObj.Gid)
|
|
}
|
|
|
|
return -1, errwrap.Wrapf(err, "group lookup error (%s)", group)
|
|
}
|
|
|
|
// RestartUnit resarts the given dbus unit and waits for it to finish starting.
|
|
func RestartUnit(ctx context.Context, conn *dbus.Conn, unit string) error {
|
|
return unitStateAction(ctx, conn, unit, DBusRestartUnit)
|
|
}
|
|
|
|
// StopUnit stops the given dbus unit and waits for it to finish stopping.
|
|
func StopUnit(ctx context.Context, conn *dbus.Conn, unit string) error {
|
|
return unitStateAction(ctx, conn, unit, DBusStopUnit)
|
|
}
|
|
|
|
// unitStateAction is a helper function to perform state actions on systemd
|
|
// units. It waits for the requested job to be complete before it returns.
|
|
func unitStateAction(ctx context.Context, conn *dbus.Conn, unit, action string) error {
|
|
// Add a dbus rule to watch the systemd1 JobRemoved signal, used to wait
|
|
// until the job completes.
|
|
args := []string{
|
|
"type='signal'",
|
|
fmt.Sprintf("path='%s'", DBusSystemd1Path),
|
|
fmt.Sprintf("interface='%s'", DBusSystemd1ManagerIface),
|
|
fmt.Sprintf("member='%s'", DBusSignalJobRemoved),
|
|
fmt.Sprintf("arg2='%s'", unit),
|
|
}
|
|
// match dbus messages
|
|
if call := conn.BusObject().Call(DBusAddMatch, 0, strings.Join(args, ",")); call.Err != nil {
|
|
return errwrap.Wrapf(call.Err, "error creating dbus call")
|
|
}
|
|
defer conn.BusObject().Call(DBusRemoveMatch, 0, args) // ignore the error
|
|
|
|
// channel for godbus signal
|
|
ch := make(chan *dbus.Signal)
|
|
defer close(ch)
|
|
// subscribe the channel to the signal
|
|
conn.Signal(ch)
|
|
defer conn.RemoveSignal(ch)
|
|
|
|
// perform requested action on specified unit
|
|
sd1 := conn.Object(DBusSystemd1Iface, dbus.ObjectPath(DBusSystemd1Path))
|
|
if call := sd1.Call(action, 0, unit, "fail"); call.Err != nil {
|
|
return errwrap.Wrapf(call.Err, "error stopping unit: %s", unit)
|
|
}
|
|
|
|
// wait for the job to be removed, indicating completion
|
|
select {
|
|
case event, ok := <-ch:
|
|
if !ok {
|
|
return fmt.Errorf("channel closed unexpectedly")
|
|
}
|
|
if event.Body[3] != "done" {
|
|
return fmt.Errorf("unexpected job status: %s", event.Body[3])
|
|
}
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("action %s on %s failed due to context timeout", action, unit)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// autoEdgeCombiner holds the state of the auto edge generator.
|
|
type autoEdgeCombiner struct {
|
|
ae []engine.AutoEdge
|
|
ptr int
|
|
}
|
|
|
|
// Next returns the next automatic edge.
|
|
func (obj *autoEdgeCombiner) Next() []engine.ResUID {
|
|
if len(obj.ae) <= obj.ptr {
|
|
panic("shouldn't be called anymore!")
|
|
}
|
|
return obj.ae[obj.ptr].Next() // return the next edge
|
|
}
|
|
|
|
// Test takes the output of the last call to Next() and outputs true if we
|
|
// should continue.
|
|
func (obj *autoEdgeCombiner) Test(input []bool) bool {
|
|
if !obj.ae[obj.ptr].Test(input) {
|
|
obj.ptr++ // match found, on to the next
|
|
}
|
|
return len(obj.ae) > obj.ptr // are there any auto edges left?
|
|
}
|
|
|
|
// AutoEdgeCombiner takes any number of AutoEdge structs, and combines them into
|
|
// a single one, so that the logic from each one can be built separately, and
|
|
// then combined using this utility. This makes implementing different AutoEdge
|
|
// generators much easier. This respects the Next() and Test() API, and ratchets
|
|
// through each AutoEdge entry until they have all run their course.
|
|
func AutoEdgeCombiner(ae ...engine.AutoEdge) (engine.AutoEdge, error) {
|
|
return &autoEdgeCombiner{
|
|
ae: ae,
|
|
}, nil
|
|
}
|