diff --git a/pgraph/pgraph.go b/pgraph/pgraph.go index 3c56040c..d981f74e 100644 --- a/pgraph/pgraph.go +++ b/pgraph/pgraph.go @@ -520,6 +520,34 @@ func (g *Graph) FilterGraph(vertices []Vertex) (*Graph, error) { return newGraph, nil } +// FilterGraphWithFn builds a new graph containing only vertices which match. It +// uses a user defined function to match. That function must return true on +// match, and an error if anything goes wrong. +func (g *Graph) FilterGraphWithFn(fn func(Vertex) (bool, error)) (*Graph, error) { + newGraph, err := NewGraph(g.Name) + if err != nil { + return nil, err + } + for k1, x := range g.adjacency { + contains, err := fn(k1) + if err != nil { + return nil, errwrap.Wrapf(err, "fn in FilterGraphWithFn() errored") + } else if contains { + newGraph.AddVertex(k1) + } + for k2, e := range x { + innerContains, err := fn(k2) + if err != nil { + return nil, errwrap.Wrapf(err, "fn in FilterGraphWithFn() errored") + } + if contains && innerContains { + newGraph.AddEdge(k1, k2, e) + } + } + } + return newGraph, nil +} + // DisconnectedGraphs returns a list containing the N disconnected graphs. func (g *Graph) DisconnectedGraphs() ([]*Graph, error) { graphs := []*Graph{}