Генерация нескольких графиков с помощью Dash - PullRequest
0 голосов
/ 18 января 2019

Я новичок в Dash/Plot.ly и сейчас пытаюсь воспроизвести с помощью Dash следующие графики (созданные в matplotlib):

enter image description here

Моя попытка сделать это состояла в том, чтобы создать метод для генерации одной фигуры:

def serve_prediction_plot(model, title, X, X_proj, y, y_proc, train_idx, test_idx, Z, xx, yy, x0, y0, d):
    """Build one Plotly figure with a prediction heatmap and the data points.

    The figure stacks three traces: a heatmap of the precomputed predictions
    ``Z`` and two scatter traces for the training and the test samples.

    Args:
        model: Estimator handed to ``cross_val_score`` and whose ``score``
            method is called. It is not fitted here, so it is presumably
            already fitted by the caller -- TODO confirm.
        title: Figure title.
        X: Feature array, indexed with ``train_idx`` / ``test_idx`` for scoring.
        X_proj: Array whose columns 0 and 1 give the marker x/y coordinates.
        y: Labels shown as marker hover text.
        y_proc: Labels used for scoring and for marker colours.
        train_idx: Index selecting the training rows.
        test_idx: Index selecting the test rows.
        Z: Grid predictions; reshaped to ``xx.shape`` before plotting.
        xx: Grid array; only its shape is used.
        yy: Unused; kept so existing callers keep working.
        x0: Heatmap x origin.
        y0: Heatmap y origin.
        d: Heatmap cell size, used for both ``dx`` and ``dy``.

    Returns:
        A ``go.Figure`` holding the heatmap and both scatter traces.
    """
    # Mean cross-validation score on the training rows, plain score on the
    # test rows.
    train_score = cross_val_score(model, X[train_idx], y_proc[train_idx]).mean()
    test_score = model.score(X[test_idx], y_proc[test_idx])

    # Two-stop colourscale for the markers; the heatmap uses the RdBu palette
    # with its colours spread evenly over [0, 1].
    bright_cscale = [[0, '#FF0000'], [1, '#0000FF']]
    rdbu_colors = cl.scales['9']['div']['RdBu']
    stops = np.linspace(0, 1, len(rdbu_colors))
    cscale = [[stop, color] for stop, color in zip(stops, rdbu_colors)]

    # Both axes are hidden: no grid, zero line, ticks or tick labels.
    axis_template = dict(
        showgrid=False,
        zeroline=False,
        linecolor='white',
        showticklabels=False,
        ticks=''
    )

    # NOTE(review): autosize=False without an explicit width/height leaves the
    # figure at Plotly's default size, which may overflow a narrow container
    # -- confirm whether this causes the cropping.
    layout = dict(
        title=title,
        xaxis=axis_template,
        yaxis=axis_template,
        showlegend=False,
        hovermode='closest',
        autosize=False,
        margin=dict(l=0, r=0, t=30, b=0)
    )

    # Prediction surface of the model.
    Z = Z.reshape(xx.shape)
    trace0 = go.Heatmap(
        z=Z,
        hoverinfo='none',
        showscale=False,
        colorscale=cscale,
        x0=x0,
        y0=y0,
        dx=d,
        dy=d
    )

    # Training samples.
    trace1 = go.Scatter(
        x=X_proj[train_idx, 0],
        y=X_proj[train_idx, 1],
        mode='markers',
        name='Training Data (accuracy={:.3f})'.format(train_score),
        text=y[train_idx],
        marker=dict(
            size=10,
            color=y_proc[train_idx],
            colorscale=bright_cscale,
            line=dict(
                width=1
            )
        )
    )

    # Test samples, drawn as triangles and labelled with the test score
    # (the label previously repeated the training score).
    trace2 = go.Scatter(
        x=X_proj[test_idx, 0],
        y=X_proj[test_idx, 1],
        mode='markers',
        name='Test Data (accuracy={:.3f})'.format(test_score),
        text=y[test_idx],
        marker=dict(
            size=10,
            symbol='triangle-up',
            color=y_proc[test_idx],
            colorscale=bright_cscale,
            line=dict(
                width=1
            ),
        )
    )

    return go.Figure(data=[trace0, trace1, trace2], layout=layout)

Этот метод вызывается при построении представления Dash:

def generate_dense_maps():
    """Return a ``row`` Div with one graph column per entry of ``service.classifiers``.

    Each column wraps a ``dcc.Graph`` whose figure comes from
    ``serve_prediction_plot`` and whose id is ``graph-<classifier name>``.
    """

    def graph_column(clf_name, clf):
        # One "two columns" cell holding the graph of a single classifier.
        figure = serve_prediction_plot(
            clf,
            clf_name,
            service.dataset.X,
            service.dataset.X_proj,
            service.dataset.y,
            service.dataset.y_proc,
            service.dataset.train_idx,
            service.dataset.test_idx,
            service.get_prediction(clf),
            service.grid.xx,
            service.grid.yy,
            service.x_min,
            service.y_min,
            service.grid.h,
        )
        graph = dcc.Graph(id='graph-{name}'.format(name=clf_name), figure=figure)
        return html.Div([graph], className="two columns")

    row_style = {
        'margin-top': '5px',

        # Remove possibility to select the text for better UX
        'user-select': 'none',
        '-moz-user-select': 'none',
        '-webkit-user-select': 'none',
        '-ms-user-select': 'none'
    }
    columns = [
        graph_column(clf_name, clf)
        for clf_name, clf in service.classifiers.items()
    ]

    return html.Div(className='row', style=row_style, children=columns)


# -------------------- Dash --------------------
app = dash.Dash(__name__)


def _plain_link(text):
    """Return an anchor styled like its surrounding text (no underline, inherited colour)."""
    return html.A(
        text,
        style={
            'text-decoration': 'none',
            'color': 'inherit'
        }
    )


# -------------------- Title Bar --------------------
_title_bar = html.Div(className="banner", children=[
    html.Div(className='container scalable', children=[
        html.H2(_plain_link('Title goes here')),

        html.A(
            html.Img(src="https://s3-us-west-1.amazonaws.com/plotly-tutorials/logo/new-branding/dash-logo-by-plotly-stripe-inverted.png"),
            href='https://plot.ly/products/dash/'
        )
    ]),
])

# -------------------- Classifiers ------------------
_classifiers_section = html.Div(
    id='div-classifiers', children=[
        html.H4(_plain_link('Classifiers')),
        generate_dense_maps()
    ]
)

# -------------------- Uncertainty ------------------
# NOTE(review): this id uses '=' where the sibling section uses '-';
# possibly a typo -- confirm before renaming.
_uncertainty_section = html.Div(
    id='div=uncertainty'
)

# -------------------- Body -------------------------
_body = html.Div(id='body', className='container scalable', children=[
    html.Div(className='row', children=[
        _classifiers_section,
        _uncertainty_section
    ])
])

# Page layout: the title bar followed by the body.
app.layout = html.Div(children=[
    _title_bar,
    _body

])

Однако изображения обрезаются:

enter image description here

Не могу понять, что я упускаю и как правильно добиться желаемого результата. Я также пытался расположить графики примерно так (что, как мне кажется, лучше смотрелось бы на веб-странице):

image1 | image2
image3 | image4
image5 | image6

Безуспешно.

МИНИМАЛЬНЫЙ ПРИМЕР

import numpy as np

# train_test_split lives in sklearn.model_selection; the old
# sklearn.cross_validation module is kept only as a fallback for
# scikit-learn releases that predate model_selection.
try:
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# Display names and the matching estimators, paired by position.
names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Decision Tree",
         "Random Forest"]
classifiers = [
    KNeighborsClassifier(3),
    SVC(kernel="linear", C=0.025, probability=True),
    SVC(gamma=2, C=1, probability=True),
    DecisionTreeClassifier(max_depth=5),
    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
]

# Two-feature synthetic classification problem, shifted by seeded uniform noise.
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
                           random_state=1, n_clusters_per_class=1)
rng = np.random.RandomState(2)
X += 2 * rng.uniform(size=X.shape)

# Grid step used for the mesh below.
h = .02
X = StandardScaler().fit_transform(X)
# 60/40 train/test split; no random_state is passed, so the split differs
# between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.4)

# Mesh covering the data range with a 0.5 margin on every side.
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

import plotly.graph_objs as go
import colorlover as cl

from sklearn.model_selection import cross_val_score

def serve_prediction_plot(model, title, X_train, y_train, X_test, y_test, xx, yy, d):
    """Fit ``model`` and build a Plotly figure of its prediction surface and data.

    The figure stacks three traces: a heatmap of the model output evaluated
    on the ``xx`` / ``yy`` grid, and two scatter traces for the training and
    the test samples.

    Args:
        model: Estimator; it is fitted in place on the training data.
        title: Figure title.
        X_train: Training features; columns 0 and 1 are plotted.
        y_train: Training labels (marker colour and hover text).
        X_test: Test features; columns 0 and 1 are plotted.
        y_test: Test labels (marker colour and hover text).
        xx: Grid x coordinates; ``Z`` is reshaped to ``xx.shape``.
        yy: Grid y coordinates, raveled together with ``xx``.
        d: Heatmap cell size, used for both ``dx`` and ``dy``.

    Returns:
        A ``go.Figure`` holding the heatmap and both scatter traces.
    """
    # Fit, then take the mean cross-validation score on the training data and
    # the plain score on the test data.
    model.fit(X_train, y_train)
    train_score = cross_val_score(model, X_train, y_train).mean()
    test_score = model.score(X_test, y_test)

    # Two-stop colourscale for the markers; the heatmap uses the RdBu palette
    # with its colours spread evenly over [0, 1].
    bright_cscale = [[0, '#FF0000'], [1, '#0000FF']]
    rdbu_colors = cl.scales['9']['div']['RdBu']
    stops = np.linspace(0, 1, len(rdbu_colors))
    cscale = [[stop, color] for stop, color in zip(stops, rdbu_colors)]

    # Both axes are hidden: no grid, zero line, ticks or tick labels.
    axis_template = dict(
        showgrid=False,
        zeroline=False,
        linecolor='white',
        showticklabels=False,
        ticks=''
    )

    # NOTE(review): autosize=False without an explicit width/height leaves the
    # figure at Plotly's default size, which may overflow a narrow container
    # -- confirm whether this causes the cropping.
    layout = dict(
        title=title,
        xaxis=axis_template,
        yaxis=axis_template,
        showlegend=False,
        hovermode='closest',
        autosize=False,
        margin=dict(l=0, r=0, t=30, b=0)
    )

    # Evaluate the model on every grid point. Prefer the second
    # predict_proba column; fall back to decision_function for estimators
    # without probability support (scikit-learn signals that with
    # AttributeError, which the original NotImplementedError handler missed).
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    try:
        Z = model.predict_proba(grid_points)[:, 1]
    except (AttributeError, NotImplementedError):
        Z = model.decision_function(grid_points)
    Z = Z.reshape(xx.shape)
    trace0 = go.Heatmap(
        z=Z,
        hoverinfo='none',
        showscale=False,
        colorscale=cscale,
        x0=xx.min(),
        y0=yy.min(),
        dx=d,
        dy=d
    )

    # Training samples.
    trace1 = go.Scatter(
        x=X_train[:, 0],
        y=X_train[:, 1],
        mode='markers',
        name='Training Data (accuracy={:.3f})'.format(train_score),
        text=y_train,
        marker=dict(
            size=10,
            color=y_train,
            colorscale=bright_cscale,
            line=dict(
                width=1
            )
        )
    )

    # Test samples, drawn as triangles and labelled with the test score
    # (the label previously repeated the training score).
    trace2 = go.Scatter(
        x=X_test[:, 0],
        y=X_test[:, 1],
        mode='markers',
        name='Test Data (accuracy={:.3f})'.format(test_score),
        text=y_test,
        marker=dict(
            size=10,
            symbol='triangle-up',
            color=y_test,
            colorscale=bright_cscale,
            line=dict(
                width=1
            ),
        )
    )

    return go.Figure(data=[trace0, trace1, trace2], layout=layout)

import dash
import dash_core_components as dcc
import dash_html_components as html

from dash.dependencies import Input, Output, State

def generate_dense_maps():
    """Return a ``row`` Div with one graph column per (name, classifier) pair.

    Pairs come from zipping the module-level ``names`` and ``classifiers``.
    Each column wraps a ``dcc.Graph`` whose figure is built by
    ``serve_prediction_plot`` and whose id is ``graph-<classifier name>``.
    """

    def graph_column(clf_name, clf):
        # One "two columns" cell holding the graph of a single classifier.
        figure = serve_prediction_plot(
            clf, clf_name, X_train, y_train, X_test, y_test, xx, yy, h
        )
        graph = dcc.Graph(id='graph-{name}'.format(name=clf_name), figure=figure)
        return html.Div([graph], className="two columns")

    row_style = {
        'margin-top': '5px',

        # Remove possibility to select the text for better UX
        'user-select': 'none',
        '-moz-user-select': 'none',
        '-webkit-user-select': 'none',
        '-ms-user-select': 'none'
    }
    columns = [
        graph_column(clf_name, clf)
        for clf_name, clf in zip(names, classifiers)
    ]

    return html.Div(className='row', style=row_style, children=columns)


# -------------------- Dash --------------------
app = dash.Dash(__name__)


def _plain_link(text):
    """Return an anchor styled like its surrounding text (no underline, inherited colour)."""
    return html.A(
        text,
        style={
            'text-decoration': 'none',
            'color': 'inherit'
        }
    )


# -------------------- Title Bar --------------------
_title_bar = html.Div(className="banner", children=[
    html.Div(className='container scalable', children=[
        html.H2(_plain_link('Title goes here')),

        html.A(
            html.Img(src="https://s3-us-west-1.amazonaws.com/plotly-tutorials/logo/new-branding/dash-logo-by-plotly-stripe-inverted.png"),
            href='https://plot.ly/products/dash/'
        )
    ]),
])

# -------------------- Classifiers ------------------
_classifiers_section = html.Div(
    id='div-classifiers', children=[
        html.H4(_plain_link('Classifiers')),
        generate_dense_maps()
    ]
)

# -------------------- Body -------------------------
_body = html.Div(id='body', className='container scalable', children=[
    html.Div(className='row', children=[
        _classifiers_section
    ])
])

# Page layout: the title bar followed by the body.
app.layout = html.Div(children=[
    _title_bar,
    _body

])

# Stylesheets registered on the app, in the order listed.
# NOTE(review): several URLs are served through RawGit; confirm that host is
# still available or swap in another one.
external_css = [
    # Normalize the CSS
    "https://cdnjs.cloudflare.com/ajax/libs/normalize/7.0.0/normalize.min.css",
    # Fonts
    "https://fonts.googleapis.com/css?family=Open+Sans|Roboto",
    "https://maxcdn.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css",
    # Base Stylesheet, replace this with your own base-styles.css using Rawgit
    "https://rawgit.com/xhlulu/9a6e89f418ee40d02b637a429a876aa9/raw/f3ea10d53e33ece67eb681025cedc83870c9938d/base-styles.css",
    # Custom Stylesheet, replace this with your own custom-styles.css using Rawgit
    "https://cdn.rawgit.com/plotly/dash-svm/bb031580/custom-styles.css"
]

for css in external_css:
    app.css.append_css({"external_url": css})

# Start the development server only when run as a script, so importing this
# module (e.g. from a WSGI server or a test) does not launch it.
if __name__ == '__main__':
    app.run_server(debug=True)
...