Source code for climb.tool.impl.data_suite.third_party.copulas.visualization

import pandas as pd

try:
    import matplotlib.pyplot as plt
except RuntimeError as e:
    if "Python is not installed as a framework." in e.message:
        import matplotlib

        matplotlib.use("PS")  # Avoid crash on macos
        import matplotlib.pyplot as plt


[docs] def scatter_3d(data, columns=None, fig=None, title=None, position=None): """Plot 3 dimensional data in a scatter plot.""" fig = fig or plt.figure() position = position or 111 ax = fig.add_subplot(position, projection="3d") ax.scatter(*(data[column] for column in columns or data.columns)) if title: ax.set_title(title) ax.title.set_position([0.5, 1.05]) return ax
[docs] def scatter_2d(data, columns=None, fig=None, title=None, position=None): """Plot 2 dimensional data in a scatter plot.""" fig = fig or plt.figure() position = position or 111 ax = fig.add_subplot(position) columns = columns or data.columns if len(columns) != 2: raise ValueError("Only 2 columns can be plotted") x, y = columns ax.scatter(data[x], data[y]) plt.xlabel(x) plt.ylabel(y) if title: ax.set_title(title) ax.title.set_position([0.5, 1.05]) return ax
[docs] def hist_1d(data, fig=None, title=None, position=None, bins=20, label=None): """Plot 1 dimensional data in a histogram.""" fig = fig or plt.figure() position = position or 111 ax = fig.add_subplot(position) ax.hist(data, density=True, bins=bins, alpha=0.8, label=label) if label: ax.legend() if title: ax.set_title(title) ax.title.set_position([0.5, 1.05]) return ax
[docs] def side_by_side(plotting_func, arrays): fig = plt.figure(figsize=(10, 4)) position_base = "1{}".format(len(arrays)) for index, (title, array) in enumerate(arrays.items()): position = int(position_base + str(index + 1)) plotting_func(array, fig=fig, title=title, position=position) plt.tight_layout()
[docs] def compare_3d(real, synth, columns=None, figsize=(10, 4)): columns = columns or real.columns fig = plt.figure(figsize=figsize) scatter_3d(real[columns], fig=fig, title="Real Data", position=121) scatter_3d(synth[columns], fig=fig, title="Synthetic Data", position=122) plt.tight_layout()
[docs] def compare_2d(real, synth, columns=None, figsize=None): x, y = columns or real.columns ax = real.plot.scatter(x, y, color="blue", alpha=0.5, figsize=figsize) ax = synth.plot.scatter(x, y, ax=ax, color="orange", alpha=0.5, figsize=figsize) ax.legend(["Real", "Synthetic"])
[docs] def compare_1d(real, synth, columns=None, figsize=None): if len(real.shape) == 1: real = pd.DataFrame({"": real}) synth = pd.DataFrame({"": synth}) columns = columns or real.columns num_cols = len(columns) fig_cols = min(2, num_cols) fig_rows = (num_cols // fig_cols) + 1 prefix = "{}{}".format(fig_rows, fig_cols) figsize = figsize or (5 * fig_cols, 3 * fig_rows) fig = plt.figure(figsize=figsize) for idx, column in enumerate(columns): position = prefix + str(idx + 1) hist_1d(real[column], fig=fig, position=position, title=column, label="Real") hist_1d(synth[column], fig=fig, position=position, title=column, label="Synthetic") plt.tight_layout()