How to plot a venn3 in plotly

4.7k views Asked by At

I have found the following code

import math
from matplotlib import pyplot as plt
import matplotlib
from matplotlib_venn import venn2, venn3
import numpy as np
from matplotlib.figure import Figure
import plotly

# Define some sets
a = {'a', 'b', 'c'}
b = {'c', 'd', 'e'}
c = {'e', 'f', 'a'}
s = [a, b, c]

# Draw the Venn diagram with matplotlib_venn. The Agg backend renders
# off-screen, so no GUI window is needed (e.g. on a web server).
matplotlib.pyplot.switch_backend('Agg')
h = venn3(s, ('A', 'B', 'C'))

# Empty plotly figure: nothing from `h` has been added to it yet.
fig = plotly.graph_objs.Figure()

# Serialise the figure to an HTML <div> string for embedding in a page.
graph_div = plotly.offline.plot(fig, auto_open=False, output_type="div")
    

What I don't know is how to draw `h` (the matplotlib venn3 diagram) in the plotly figure. I'm trying to implement venn2 and venn3 diagrams of keyword searches on my website.

1

There is 1 answer

1
Davide_sd On BEST ANSWER

This is the first result of a web search:

import matplotlib.pyplot as plt
from matplotlib_venn import venn2, venn3

import plotly as py
from plotly.offline import iplot
import plotly.graph_objs as go
import plotly.io as pio
pio.renderers.default = 'iframe'

import scipy


def _text_to_annotation(text):
    """Convert a matplotlib text label into a plotly annotation without arrow."""
    position = text.get_position()
    return go.layout.Annotation(
        xref="x",
        yref="y",
        x=position[0],
        y=position[1],
        text=text.get_text(),
        showarrow=False
    )


def venn_to_plotly(L_sets, L_labels=None, title=None, show=True):
    """Draw a 2- or 3-set Venn diagram as a plotly figure.

    The geometry (circle centers, radii, label positions) is computed by
    matplotlib_venn's venn2/venn3; the circles and labels are then redrawn
    as plotly layout shapes and annotations.

    Args:
        L_sets: Sequence of two or three sets, passed to venn2/venn3.
        L_labels: Optional sequence of set names. It is only used when its
            length equals the number of sets; otherwise the venn2/venn3
            default labels are used.
        title: Optional figure title.
        show: If true (the default, matching the previous behaviour) the
            figure is displayed with ``show()`` before it is returned.
            Pass ``show=False`` to only build the figure, e.g. to hand it
            to ``plotly.offline.plot(fig, output_type="div")``.

    Returns:
        The plotly ``go.Figure``. In a notebook the returned figure may be
        rendered a second time if the call is the last expression of a
        cell; assign the result or pass ``show=False`` to avoid that.

    Raises:
        ValueError: If ``L_sets`` does not contain exactly 2 or 3 sets.
    """
    # get number of sets
    n_sets = len(L_sets)

    # choose the matplotlib venn diagram; fail early with a clear message
    # instead of a NameError on `v` further down
    if n_sets == 2:
        venn = venn2
    elif n_sets == 3:
        venn = venn3
    else:
        raise ValueError(
            "venn_to_plotly supports 2 or 3 sets, got {}".format(n_sets)
        )

    # create the matplotlib venn diagram
    if L_labels and len(L_labels) == n_sets:
        v = venn(L_sets, L_labels)
    else:
        v = venn(L_sets)
    # suppress output of the matplotlib venn diagram; only its geometry
    # and label objects are needed from here on
    plt.close()

    # Create empty lists to hold shapes and annotations
    L_shapes = []
    L_annotation = []

    # Define color list for sets
    # check for other colors: https://css-tricks.com/snippets/css/named-colors-and-hex-equivalents/
    L_color = ['FireBrick', 'DodgerBlue', 'DimGrey']

    # Bounding-box edges of every circle, used to set the axis ranges
    L_x_max = []
    L_y_max = []
    L_x_min = []
    L_y_min = []

    for i in range(n_sets):
        # bounding box of the circle for the current set
        radius = v.radii[i]
        x0 = v.centers[i][0] - radius
        x1 = v.centers[i][0] + radius
        y0 = v.centers[i][1] - radius
        y1 = v.centers[i][1] + radius

        # create circle shape for current set
        L_shapes.append(
            go.layout.Shape(
                type="circle",
                xref="x",
                yref="y",
                x0=x0,
                y0=y0,
                x1=x1,
                y1=y1,
                fillcolor=L_color[i],
                line_color=L_color[i],
                opacity=0.75
            )
        )

        # create set label for current set
        L_annotation.append(_text_to_annotation(v.set_labels[i]))

        # remember min and max values of current set shape
        L_x_min.append(x0)
        L_x_max.append(x1)
        L_y_min.append(y0)
        L_y_max.append(y1)

    # create subset labels (number of elements in each region).
    # v.subset_labels already holds one entry per region (3 for venn2,
    # 7 for venn3), so the count does not have to be computed separately.
    for subset_label in v.subset_labels:
        # matplotlib_venn stores None for a region it does not draw
        if subset_label is None:
            continue
        L_annotation.append(_text_to_annotation(subset_label))

    # define off_set for the figure range
    off_set = 0.2

    # get min and max for x and y dimension to set the figure range
    x_max = max(L_x_max) + off_set
    x_min = min(L_x_min) - off_set
    y_max = max(L_y_max) + off_set
    y_min = min(L_y_min) - off_set

    # create plotly figure
    p_fig = go.Figure()

    # set xaxes range and hide ticks and ticklabels
    p_fig.update_xaxes(
        range=[x_min, x_max],
        showticklabels=False,
        ticklen=0
    )

    # set yaxes range and hide ticks and ticklabels; anchoring the y scale
    # to x with ratio 1 keeps the circles round
    p_fig.update_yaxes(
        range=[y_min, y_max],
        scaleanchor="x",
        scaleratio=1,
        showticklabels=False,
        ticklen=0
    )

    # set figure properties and add shapes and annotations
    p_fig.update_layout(
        plot_bgcolor='white',
        margin=dict(b=0, l=10, pad=0, r=10, t=40),
        width=800,
        height=400,
        shapes=L_shapes,
        annotations=L_annotation,
        title=dict(text=title, x=0.5, xanchor='center')
    )

    if show:
        p_fig.show()

    return p_fig

Then, use that function with your data:

# `s` is the list of three sets [a, b, c] defined in the question's snippet
venn_to_plotly(s, ('A', 'B', 'C'))

enter image description here