# -*- coding: utf-8 -*-
"""Useful functions for plotting
"""
# %% IMPORTS
import numpy as np
import matplotlib.pyplot as plt
# %% COLLECTION FOR 3D ARROW PLOTTING
# source : https://gist.github.com/WetHat/1d6cd0f7309535311a539b42cccca89c
from matplotlib.patches import FancyArrowPatch
from matplotlib.text import Annotation
from mpl_toolkits.mplot3d.proj3d import proj_transform
from mpl_toolkits.mplot3d.axes3d import Axes3D
[docs]
class Arrow3D(FancyArrowPatch):
def __init__(self, x, y, z, dx, dy, dz, *args, **kwargs):
super().__init__((0, 0), (0, 0), *args, **kwargs)
self._xyz = (x, y, z)
self._dxdydz = (dx, dy, dz)
[docs]
def draw(self, renderer):
x1, y1, z1 = self._xyz
dx, dy, dz = self._dxdydz
x2, y2, z2 = (x1 + dx, y1 + dy, z1 + dz)
xs, ys, zs = proj_transform((x1, x2), (y1, y2), (z1, z2), self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
super().draw(renderer)
def do_3d_projection(self, renderer=None):
x1, y1, z1 = self._xyz
dx, dy, dz = self._dxdydz
x2, y2, z2 = (x1 + dx, y1 + dy, z1 + dz)
xs, ys, zs = proj_transform((x1, x2), (y1, y2), (z1, z2), self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
return np.min(zs)
[docs]
class Annotation3D(Annotation):
def __init__(self, text, xyz, *args, **kwargs):
super().__init__(text, xy=(0, 0), *args, **kwargs)
self._xyz = xyz
[docs]
def draw(self, renderer):
x2, y2, z2 = proj_transform(*self._xyz, self.axes.M)
self.xy = (x2, y2)
super().draw(renderer)
def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
"""Add an 3d arrow to an `Axes3D` instance."""
arrow = Arrow3D(x, y, z, dx, dy, dz, *args, **kwargs)
ax.add_artist(arrow)
def _annotate3D(ax, text, xyz, *args, **kwargs):
"""Add anotation `text` to an `Axes3d` instance."""
annotation = Annotation3D(text, xyz, *args, **kwargs)
ax.add_artist(annotation)
setattr(Axes3D, "arrow3D", _arrow3D)
setattr(Axes3D, "annotate3D", _annotate3D)