Files
mgmt/lang/unification/util/util.go
James Shubin ad4eb86262 lang: unification: util: Add the core unification helpers
This adds the core unification helper functions that do the core work of
solving the invariants. This includes the actual Unify, OccursCheck, and
Extract which is sometimes known as "zonk".

A few other small functions are also included.

Co-authored-by: Samuel Gélineau <gelisam@gmail.com>
2024-07-01 16:06:58 -04:00

843 lines
24 KiB
Go

// Mgmt
// Copyright (C) 2013-2024+ James Shubin and the project contributors
// Written by James Shubin <james@shubin.ca> and the project contributors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//
// Additional permission under GNU GPL version 3 section 7
//
// If you modify this program, or any covered work, by linking or combining it
// with embedded mcl code and modules (and that the embedded mcl code and
// modules which link with this program, contain a copy of their source code in
// the authoritative form) containing parts covered by the terms of any other
// license, the licensors of this program grant you additional permission to
// convey the resulting work. Furthermore, the licensors of this program grant
// the original author, James Shubin, additional permission to update this
// additional permission if he deems it necessary to achieve the goals of this
// additional permission.
// Package util contains some utility functions and algorithms which are useful
// for type unification.
package util
import (
"fmt"
"github.com/purpleidea/mgmt/lang/types"
)
// Unify takes two types and tries to make them equivalent. It works by drawing
// conclusions from the assertion that the two sides are equal: a unification
// variable on one side must be equal to the sub-tree found at the same position
// on the other side, and vice versa. If neither side contains any unification
// variables, this degrades to a structural comparison, which is intended to
// match what the *types.Type Cmp method does. Func arg names are not compared,
// only the arg positions are. This may modify the input types, so copy them
// before use if this is an issue. If you only want to do a compare, you can
// safely use UnifyCmp.
//
// An error is returned if either input is nil, or if the two types can't be
// made equal. Two func types which both have no return type (a nil Out) are
// considered to agree on their return type; if only one of them has a nil Out
// then that is an error. This function panics on malformed struct or func
// types: a nil Map, an Ord entry that is missing from the Map, or a nil func
// arg.
func Unify(typ1, typ2 *types.Type) error {
	if typ1 == nil || typ2 == nil {
		return fmt.Errorf("nil type")
	}

	// We deliberately don't short-circuit with typ1.Cmp(typ2) when neither
	// top-level type is a unification variable, because either side could
	// still contain nested unification variables. The recursion below is
	// what deals with those.

	switch {
	// Here the left side is a ?1 type, and the other one *might* be a full
	// type or it might even contain a ?2. For example it could be a [?2] or
	// a [[?2]].
	case typ1.Uni != nil && typ2.Uni == nil:
		root := typ1.Uni.Find()
		if root.Data != nil {
			// We already know something, so it must agree with typ2.
			return Unify(root.Data, typ2)
		}
		// We don't yet know anything about this unification variable.
		if err := OccursCheck(root, typ2); err != nil {
			return err
		}
		root.Data = typ2 // learn!
		return nil

	// This is the same case as above, except it's the opposite scenario.
	case typ1.Uni == nil && typ2.Uni != nil:
		root := typ2.Uni.Find()
		if root.Data != nil {
			// We already know something, so it must agree with typ1.
			return Unify(root.Data, typ1)
		}
		// We don't yet know anything about this unification variable.
		if err := OccursCheck(root, typ1); err != nil {
			return err
		}
		root.Data = typ1 // learn!
		return nil

	// Both of these are of the form ?1 and ?2 so we compare them directly.
	case typ1.Uni != nil && typ2.Uni != nil:
		root1 := typ1.Uni.Find()
		root2 := typ2.Uni.Find()
		switch {
		case root1.Data == nil && root2.Data == nil:
			// We don't need a merge function to wrap the Union call
			// because in this scenario, both data fields are empty!
			root1.Union(root2) // merge!
			return nil
		case root1.Data == nil: // only the right side is solved
			return Unify(typ1, root2.Data)
		case root2.Data == nil: // only the left side is solved
			return Unify(root1.Data, typ2)
		}
		// Both sides are solved, so the two solutions must agree.
		return Unify(root1.Data, root2.Data)
	}

	// At this point, we've handled all the special cases with the ?1, ?2
	// unification variables, so we now expect the kind's to match to unify.
	if k1, k2 := typ1.Kind, typ2.Kind; k1 != k2 {
		return fmt.Errorf("type error: %v != %v", k1, k2)
	}

	switch typ1.Kind { // same as typ2.Kind from here on
	case types.KindBool, types.KindStr, types.KindInt, types.KindFloat:
		return nil

	case types.KindList:
		return Unify(typ1.Val, typ2.Val)

	case types.KindMap:
		if err := Unify(typ1.Key, typ2.Key); err != nil {
			return err
		}
		return Unify(typ1.Val, typ2.Val)

	case types.KindStruct:
		if typ1.Map == nil || typ2.Map == nil {
			panic("malformed struct type")
		}
		if len(typ1.Ord) != len(typ2.Ord) {
			return fmt.Errorf("struct field count differs")
		}
		// Check all of the field names before we unify any of the field
		// types, so that a name mismatch is found before we learn
		// anything from the fields.
		for i, k := range typ1.Ord {
			if k != typ2.Ord[i] {
				return fmt.Errorf("struct fields differ")
			}
		}
		for _, k := range typ1.Ord { // loop map in deterministic order
			t1, ok := typ1.Map[k]
			if !ok {
				panic("malformed struct order")
			}
			t2, ok := typ2.Map[k]
			if !ok {
				panic("malformed struct order")
			}
			// A nil field is caught by the nil check at the top of
			// the recursive call.
			if err := Unify(t1, t2); err != nil {
				return err
			}
		}
		return nil

	case types.KindFunc:
		if typ1.Map == nil || typ2.Map == nil {
			panic("malformed func type")
		}
		if len(typ1.Ord) != len(typ2.Ord) {
			return fmt.Errorf("func arg count differs")
		}
		// We're not comparing arg names (that would only be needed for
		// a strict cmp) so we pair up the args by their position.
		for i := range typ1.Ord {
			t1, ok := typ1.Map[typ1.Ord[i]]
			if !ok {
				panic("malformed func order")
			}
			if t1 == nil {
				panic("malformed func arg")
			}
			t2, ok := typ2.Map[typ2.Ord[i]]
			if !ok {
				panic("malformed func order")
			}
			if t2 == nil {
				panic("malformed func arg")
			}
			if err := Unify(t1, t2); err != nil {
				return err
			}
		}
		// A func without a return type has a nil Out. Two of those
		// agree with each other. Other helpers in this package (Unpack
		// and UnifyCopy) also accept a nil Out as being well-formed.
		if typ1.Out == nil && typ2.Out == nil {
			return nil
		}
		// If only one side is nil, this errors with "nil type".
		return Unify(typ1.Out, typ2.Out)

	// Unsure if this case is ever used.
	case types.KindVariant:
		// TODO: should we Unify typ1.Var with typ2.Var ?
		return Unify(typ1.Var, typ2.Var)
	}

	// programming error
	return fmt.Errorf("unhandled type case")
}
// OccursCheck determines if elem exists inside of this type. This is important
// so that we can avoid infinite self-referential types. If we find that it does
// occur, then we error. This: `?1 occurs-in [[?2]]` is what we're doing here!
//
// This function panics if you pass in a nil elem, a nil type, or a malformed
// type. It also panics if it reaches a unification variable inside of typ while
// the set that elem belongs to already has a populated .Data field. The reason
// for that restriction is that if ?0 is equal to []?1, then to prevent infinite
// types, it does not suffice to check that ?0 does not occur in the type, we
// must also check that []?1 does not occur. That check is harder (and slower)
// to implement. So we just don't implement it, and we ask the caller to only
// call OccursCheck in the easy case where ?0 is not yet equal to anything else.
//
// A func type without a return type (a nil Out) is accepted, there is simply
// nothing to look at in that position.
func OccursCheck(elem *types.Elem, typ *types.Type) error {
	if elem == nil {
		panic("nil elem")
	}
	if typ == nil {
		panic("nil type")
	}

	if typ.Uni != nil {
		root1 := elem.Find()
		root2 := typ.Uni.Find()
		if root1 == root2 {
			return fmt.Errorf("directly in the same set")
		}
		// We don't look at root1.Data as it's not important because we
		// only call OccursCheck if root1.Data (which is the single
		// representative for the elem set by union find) is nil.
		// TODO: Move this check to the top of the function for safety?
		if root1.Data != nil {
			// programming error
			panic("unexpected non-nil data")
		}
		if root2.Data == nil {
			return nil // We don't occur again, we are done!
		}
		// The other variable is solved, so look inside its solution.
		return OccursCheck(root1, root2.Data)
	}

	// Now we know that `typ.Uni == nil`. There could still be a ?1 inside
	// of a recursive type .Val, .Key, etc field. Check through all of them.
	switch typ.Kind {
	case types.KindBool, types.KindStr, types.KindInt, types.KindFloat:
		return nil

	case types.KindList:
		return OccursCheck(elem, typ.Val)

	case types.KindMap:
		if err := OccursCheck(elem, typ.Key); err != nil {
			return err
		}
		return OccursCheck(elem, typ.Val)

	case types.KindStruct:
		for _, k := range typ.Ord {
			t, ok := typ.Map[k]
			if !ok {
				panic("malformed struct order")
			}
			// A nil field is caught by the nil check at the top of
			// the recursive call.
			if err := OccursCheck(elem, t); err != nil {
				return err
			}
		}
		return nil

	case types.KindFunc:
		for _, k := range typ.Ord {
			t, ok := typ.Map[k]
			if !ok {
				panic("malformed func order")
			}
			// A nil arg is caught by the nil check at the top of
			// the recursive call.
			if err := OccursCheck(elem, t); err != nil {
				return err
			}
		}
		// A func without a return type has a nil Out, and a missing
		// type can't contain elem. Other helpers in this package
		// (Unpack and UnifyCopy) also accept a nil Out as well-formed.
		if typ.Out == nil {
			return nil
		}
		return OccursCheck(elem, typ.Out)

	// Unsure if this case is ever used.
	case types.KindVariant:
		return OccursCheck(elem, typ.Var)
	}

	// programming error
	panic("malformed type")
}
// Extract walks the whole input type and replaces every solved unification
// variable it finds with the type stored in that variable's data slot. Children
// are resolved first, and then the result is stored back into the parent. If
// there is nothing to resolve, the input is returned without modification. A
// nil input is returned as is.
//
// Nothing is copied here: container types are updated in place and the same
// pointer is handed back. If in any doubt, copy the result. We *never* copy the
// actual unification variable itself. We keep the pointer!
//
// This shape came out of looking at post-unification unification variables and
// realizing that the solved data needed to be "bubbled upwards". GHC Haskell
// has a similar function which it calls "zonk", whimsically claiming that it is
// named "after the sound it makes". (Thanks to Sam for the fun trivia!)
//
// This panics if it encounters a malformed type.
// TODO: Add an alternate version that copies everything if anything changes.
func Extract(typ *types.Type) *types.Type {
	if typ == nil {
		return nil // or panic?
	}
	if typ.Uni != nil {
		return extract(typ) // call private helper, not a recursion!
	}

	switch typ.Kind {
	case types.KindBool, types.KindStr, types.KindInt, types.KindFloat:
		return typ // nothing inside to resolve

	case types.KindList:
		if typ.Val == nil {
			panic("malformed list type")
		}
		if elem := Extract(typ.Val); elem != typ.Val {
			typ.Val = elem // update in place, we don't copy
		}
		return typ

	case types.KindMap:
		if typ.Key == nil || typ.Val == nil {
			panic("malformed map type")
		}
		k := Extract(typ.Key)
		v := Extract(typ.Val)
		if k != typ.Key || v != typ.Val {
			typ.Key = k // update in place, we don't copy
			typ.Val = v
		}
		return typ

	case types.KindStruct: // {a bool; b int}
		if typ.Map == nil {
			panic("malformed struct type")
		}
		if len(typ.Map) != len(typ.Ord) {
			panic("malformed struct length")
		}
		// Resolve every field first, and only then store the results.
		changed := false
		solved := make([]*types.Type, len(typ.Ord))
		for i, name := range typ.Ord {
			field, ok := typ.Map[name]
			if !ok {
				panic("malformed struct order")
			}
			if field == nil {
				panic("malformed struct field")
			}
			solved[i] = Extract(field)
			if solved[i] != field {
				changed = true
			}
		}
		if changed {
			for i, name := range typ.Ord {
				typ.Map[name] = solved[i] // in place, no copy
			}
		}
		return typ

	case types.KindFunc:
		if typ.Map == nil {
			panic("malformed func type")
		}
		if len(typ.Map) != len(typ.Ord) {
			panic("malformed func length")
		}
		// Resolve every arg first, and only then store the results.
		changed := false
		solved := make([]*types.Type, len(typ.Ord))
		for i, name := range typ.Ord {
			arg, ok := typ.Map[name]
			if !ok {
				panic("malformed func order")
			}
			if arg == nil {
				panic("malformed func field")
			}
			solved[i] = Extract(arg)
			if solved[i] != arg {
				changed = true
			}
		}
		out := Extract(typ.Out) // a nil Out comes back as nil
		if changed || out != typ.Out {
			for i, name := range typ.Ord {
				typ.Map[name] = solved[i] // in place, no copy
			}
			typ.Out = out
		}
		return typ

	case types.KindVariant:
		if inner := Extract(typ.Var); inner != typ.Var {
			typ.Var = inner // update in place, we don't copy
		}
		return typ

	case types.KindUnification:
		return extract(typ) // duplicate case that is handled at the top
	}

	panic("malformed type")
}
// extract is the private helper for Extract. It only looks at the unification
// variable that is directly on the input type. If the input isn't a unification
// variable, it is returned without modification. If the variable is solved,
// then the fully extracted solution is returned. If it is not solved yet, then
// a new unification variable type that points at the root of the set is
// returned instead. The recursive version named Extract is probably what you
// actually want to call.
func extract(typ *types.Type) *types.Type {
	if typ.Uni == nil {
		return typ
	}

	// The types of every Elem in the set are known to be equal to each
	// other, so there is only one meaningful .Data for the whole set. Our
	// union-find library lets each Elem hold its own .Data value, so we
	// follow the convention that only the root of each set has meaningful
	// .Data, and the rest have junk, outdated values. This is why we must
	// definitely call Find here, using typ.Uni directly would be wrong!
	root := typ.Uni.Find()

	if solved := root.Data; solved != nil {
		// NOTE: It's important that we recurse here, since the solution
		// can contain more unification variables. Not doing this was an
		// earlier major bug.
		return Extract(solved)
	}

	// Returning the input type here would be the normal implementation.
	// Returning the root this way instead makes debugging easier, and it is
	// legal to leave in forever too. (Says Sam!)
	return &types.Type{
		Kind: types.KindUnification,
		Uni:  root,
	}
}
// Unpack takes a solved type out of a unification variable data slot, and
// unpacks it over that unification variable to turn the type into a normal one.
// This returns false if no action was taken, or true if it modified the input.
// This public function looks at the whole type recursively: container types
// have each of their children unpacked, and when the input is itself a solved
// unification variable, the solution that gets unpacked over it is then
// unpacked as well, since it can contain more solved unification variables.
// Unification variables which are not solved yet are left alone. A nil input
// returns false. This panics if it encounters a malformed type.
func Unpack(typ *types.Type) bool {
	if typ == nil {
		return false
	}
	if typ.Uni != nil {
		if !unpack(typ) { // call private helper, not a recursion!
			return false // not solved yet, nothing to do
		}
		// The solution that was just squashed over typ can itself
		// contain solved unification variables. (For example if we had
		// ?1 = []?2 and ?2 = int.) Keep going so that those are
		// unpacked too. This is the same recursion that extract needs.
		Unpack(typ)
		return true // the squash above already modified the input
	}

	switch typ.Kind {
	case types.KindBool, types.KindStr, types.KindInt, types.KindFloat:
		return false

	case types.KindList:
		if typ.Val == nil {
			panic("malformed list type")
		}
		return Unpack(typ.Val)

	case types.KindMap:
		if typ.Key == nil || typ.Val == nil {
			panic("malformed map type")
		}
		// Always visit both, don't short-circuit after the key.
		b := Unpack(typ.Key)
		if Unpack(typ.Val) {
			b = true
		}
		return b

	case types.KindStruct: // {a bool; b int}
		if typ.Map == nil {
			panic("malformed struct type")
		}
		if len(typ.Map) != len(typ.Ord) {
			panic("malformed struct length")
		}
		b := false
		for _, k := range typ.Ord { // loop map in deterministic order
			t, ok := typ.Map[k]
			if !ok {
				panic("malformed struct order")
			}
			if t == nil {
				panic("malformed struct field")
			}
			if Unpack(t) {
				b = true
			}
		}
		return b

	case types.KindFunc:
		if typ.Map == nil {
			panic("malformed func type")
		}
		if len(typ.Map) != len(typ.Ord) {
			panic("malformed func length")
		}
		b := false
		for _, k := range typ.Ord { // loop map in deterministic order
			t, ok := typ.Map[k]
			if !ok {
				panic("malformed func order")
			}
			if t == nil {
				panic("malformed func field")
			}
			if Unpack(t) {
				b = true
			}
		}
		// A func without a return type has a nil Out.
		if typ.Out != nil && Unpack(typ.Out) {
			b = true
		}
		return b

	case types.KindVariant:
		return Unpack(typ.Var)

	case types.KindUnification:
		return unpack(typ) // duplicate case that is handled at the top
	}

	panic("malformed type")
}
// unpack is the private helper for Unpack. It only looks at the unification
// variable that is directly on the input type. If that variable is solved, the
// solution is written over the input type, which turns it into a normal type,
// and true is returned. If the input isn't a unification variable, or if it is
// not solved yet, nothing is modified and false is returned. The recursive
// version named Unpack is probably what you actually want to call.
func unpack(typ *types.Type) bool {
	if typ.Uni == nil {
		return false
	}

	// The types of every Elem in the set are known to be equal to each
	// other, so there is only one meaningful .Data for the whole set. Our
	// union-find library lets each Elem hold its own .Data value, so we
	// follow the convention that only the root of each set has meaningful
	// .Data, and the rest have junk, outdated values. This is why we must
	// definitely call Find here, using typ.Uni directly would be wrong!
	solved := typ.Uni.Find().Data
	if solved == nil {
		return false
	}

	*typ = *solved // squash it
	return true
}
// UnifyCmp is like the standard *types.Type Cmp() except that it allows one of
// the two types to include unification variables. It returns nil if the two
// types are consistent, identical, and don't contain any remaining unification
// variables once they've been unified. This function copies the input types
// before it runs the underlying Unify operation on them, so this can be used
// safely without side effects. In the simple case where neither type contains
// unification variables, this is a more expensive version of typ1.Cmp(typ2).
//
// It is not necessarily valid at this time to call this if both sides contain
// unification variables. In that situation, this implementation will panic. It
// does not matter which side is the one that contains unification variables, as
// long as both of them don't. The textbook terminology is that we're checking
// that the "scrutinee" (the type which cannot contain unification variables)
// "matches" the "type pattern" (the type which can contain unification
// variables). Thanks to Sam for that information.
// TODO: It's possible that this could be useful with both sides containing
// unification variables, but keep this panic in until we find that case and add
// tests for it.
func UnifyCmp(typ1, typ2 *types.Type) error {
	if typ1.HasUni() && typ2.HasUni() {
		panic("both types have unification variables")
	}

	// Work on copies so that Unify doesn't modify what the caller gave us.
	lhs := typ1.Copy()
	rhs := typ2.Copy()
	if err := Unify(lhs, rhs); err != nil {
		return err
	}

	// Resolve whatever got solved, and then compare what is left.
	lhs = Extract(lhs)
	rhs = Extract(rhs)
	if !lhs.HasUni() && !rhs.HasUni() {
		return lhs.Cmp(rhs)
	}

	// Something is still unsolved, so we can't say that they're identical.
	return fmt.Errorf("ambiguous type cmp")
}
// UnifyCopy returns a deep copy of the input type, with the exception of any
// unification variables (the .Uni field) that it encounters. Those are never
// copied, the new type points at the very same ones. It panics if it encounters
// an invalid or partial type struct.
func UnifyCopy(typ *types.Type) *types.Type {
	switch typ.Kind {
	case types.KindBool, types.KindStr, types.KindInt, types.KindFloat:
		return &types.Type{Kind: typ.Kind}

	case types.KindList:
		if typ.Val == nil {
			panic("malformed list type")
		}
		val := UnifyCopy(typ.Val)
		return &types.Type{Kind: typ.Kind, Val: val}

	case types.KindMap:
		if typ.Key == nil || typ.Val == nil {
			panic("malformed map type")
		}
		key := UnifyCopy(typ.Key)
		val := UnifyCopy(typ.Val)
		return &types.Type{Kind: typ.Kind, Key: key, Val: val}

	case types.KindStruct: // {a bool; b int}
		if typ.Map == nil {
			panic("malformed struct type")
		}
		if len(typ.Map) != len(typ.Ord) {
			panic("malformed struct length")
		}
		dup := &types.Type{
			Kind: typ.Kind,
			Map:  make(map[string]*types.Type),
		}
		for _, name := range typ.Ord { // keep the field order
			field, ok := typ.Map[name]
			if !ok {
				panic("malformed struct order")
			}
			if field == nil {
				panic("malformed struct field")
			}
			dup.Map[name] = UnifyCopy(field)
			dup.Ord = append(dup.Ord, name)
		}
		return dup

	case types.KindFunc:
		if typ.Map == nil {
			panic("malformed func type")
		}
		if len(typ.Map) != len(typ.Ord) {
			panic("malformed func length")
		}
		dup := &types.Type{
			Kind: typ.Kind,
			Map:  make(map[string]*types.Type),
		}
		for _, name := range typ.Ord { // keep the arg order
			arg, ok := typ.Map[name]
			if !ok {
				panic("malformed func order")
			}
			if arg == nil {
				panic("malformed func field")
			}
			dup.Map[name] = UnifyCopy(arg)
			dup.Ord = append(dup.Ord, name)
		}
		if typ.Out != nil { // a func might not have a return type
			dup.Out = UnifyCopy(typ.Out)
		}
		return dup

	case types.KindVariant:
		inner := UnifyCopy(typ.Var)
		return &types.Type{Kind: typ.Kind, Var: inner}

	case types.KindUnification:
		if typ.Uni == nil {
			panic("malformed unification variable")
		}
		// don't copy! (we want the same pointer)
		return &types.Type{Kind: typ.Kind, Uni: typ.Uni}
	}

	panic("malformed type")
}