diff --git a/lang/structs.go b/lang/structs.go index 3fccc088..a41e8bde 100644 --- a/lang/structs.go +++ b/lang/structs.go @@ -5583,6 +5583,18 @@ func (obj *ExprIf) Unify() ([]interfaces.Invariant, error) { } invariants = append(invariants, branchesInvar) + // the two branches must match the type of the whole expression + thenInvar := &unification.EqualityInvariant{ + Expr1: obj, + Expr2: obj.ThenBranch, + } + invariants = append(invariants, thenInvar) + elseInvar := &unification.EqualityInvariant{ + Expr1: obj, + Expr2: obj.ElseBranch, + } + invariants = append(invariants, elseInvar) + return invariants, nil }