def get_data(self):
    """Return the Sankey links as a list of records, rejecting loops.

    The dataframe returned by ``self.get_df()`` has its columns renamed
    to ``source``, ``target`` and ``value`` and is converted to a list
    of dicts (one per link).

    Raises:
        Exception: if the links contain a directed cycle. The message
            includes one ``(source, target)`` link that closes the loop.
    """
    df = self.get_df()
    df.columns = ['source', 'target', 'value']
    recs = df.to_dict(orient='records')

    # Adjacency sets: source -> {targets}.
    hierarchy = defaultdict(set)
    for row in recs:
        hierarchy[row['source']].add(row['target'])

    def find_cycle(graph):
        """Return a (source, target) edge closing a cycle, or None.

        Iterative depth-first search. ``on_path`` holds the vertices of
        the branch currently being explored; ``finished`` holds vertices
        whose whole subtree is known to be loop-free, so each vertex is
        expanded only once.
        """
        on_path = set()
        finished = set()
        for root in graph:
            if root in finished:
                continue
            on_path.add(root)
            # `.get` (not indexing) so the defaultdict is never mutated
            # while it is being iterated.
            stack = [(root, iter(graph.get(root, ())))]
            while stack:
                vertex, successors = stack[-1]
                for successor in successors:
                    if successor in on_path:
                        # This is the actual back edge; report it as-is
                        # rather than an ancestor link leading to it.
                        return (vertex, successor)
                    if successor not in finished:
                        on_path.add(successor)
                        stack.append(
                            (successor, iter(graph.get(successor, ()))))
                        break
                else:
                    # Every successor explored without finding a loop.
                    stack.pop()
                    on_path.discard(vertex)
                    finished.add(vertex)
        return None

    cycle = find_cycle(hierarchy)
    if cycle:
        raise Exception(
            "There's a loop in your Sankey, please provide a tree. "
            "Here's a faulty link: {}".format(cycle))
    return recs