import matplotlib.pyplot as plt
import numpy as np
from MCEq.geometry import *
from MCEq.misc import theta_rad

# Quick-look plots of the EarthGeometry helper from MCEq.geometry:
#   1. total path length l(theta) vs. zenith angle at the detector,
#   2. zenith angle theta* at the top of the atmosphere vs. theta,
#   3. path length Delta l(h) as a function of height, at one fixed theta,
#   4. height h(Delta l, theta) along the trajectory for several theta.
# Lengths and heights returned by EarthGeometry are divided by 1e5 and
# labelled km, i.e. this script assumes the geometry works in cm
# (NOTE(review): confirm against MCEq.geometry).

# Divisor applied to geometry lengths before plotting (assumed cm -> km).
CM_PER_KM = 1e5
# Zenith angle (degrees) used for the fixed-angle Delta l(h) plot.
DELTA_L_THETA_DEG = 85.
# Zenith angles (degrees) for which h(Delta l, theta) is drawn.
TRAJECTORY_THETAS_DEG = [30., 60., 70., 80., 85., 90.]


def _new_figure():
    """Open a new 5x4 inch figure with tight layout and return it."""
    new_fig = plt.figure(figsize=(5, 4))
    new_fig.set_tight_layout(dict(rect=[0.00, 0.00, 1, 1]))
    return new_fig


def _despine(axes):
    """Hide the top/right spines of *axes*, keep ticks bottom/left; return it."""
    axes.spines['right'].set_visible(False)
    axes.spines['top'].set_visible(False)
    axes.xaxis.set_ticks_position('bottom')
    axes.yaxis.set_ticks_position('left')
    return axes


g = EarthGeometry()
theta_list = np.linspace(0, 90, 500)
h_vec = np.linspace(0, g.h_atm, 500)
th_list_rad = theta_rad(theta_list)

# 1. Total path length vs. zenith angle at the detector.
fig = _new_figure()
plt.plot(theta_list, g.l(th_list_rad) / CM_PER_KM, lw=2)
plt.xlabel(r'zenith $\theta$ at detector [deg]')
plt.ylabel(r'path length $l(\theta)$ [km]')
ax = _despine(plt.gca())

# 2. Zenith angle at the top of the atmosphere vs. zenith at the detector.
fig = _new_figure()
plt.plot(theta_list, np.degrees(np.arccos(g.cos_th_star(th_list_rad))),
         lw=2)
plt.xlabel(r'zenith $\theta$ at detector [deg]')
plt.ylabel(r'$\theta^*$ at top of the atm. [deg]')
plt.ylim([0, 90])
ax = _despine(plt.gca())

# 3. Path length Delta l(h) vs. height for one fixed zenith angle; the
#    angle is shown in the legend so the figure is self-describing.
fig = _new_figure()
plt.plot(h_vec / CM_PER_KM,
         g.delta_l(h_vec, theta_rad(DELTA_L_THETA_DEG)) / CM_PER_KM,
         label=r'${0:.0f}^\circ$'.format(DELTA_L_THETA_DEG), lw=2)
plt.legend()
plt.ylabel(r'path length $\Delta l(h)$ [km]')
plt.xlabel(r'atm. height $h_{atm}$ [km]')
ax = _despine(plt.gca())

# 4. Height along the trajectory for Delta l from 0 to the full l(theta).
fig = _new_figure()
for theta in TRAJECTORY_THETAS_DEG:
    theta_path = theta_rad(theta)
    delta_l_vec = np.linspace(0, g.l(theta_path), 1000)
    plt.plot(delta_l_vec / CM_PER_KM,
             g.h(delta_l_vec, theta_path) / CM_PER_KM,
             label=r'${0:.0f}^\circ$'.format(theta), lw=2)
plt.legend()
plt.xlabel(r'path length $\Delta l$ [km]')
plt.ylabel(r'atm. height $h_{atm}(\Delta l, \theta)$ [km]')
ax = _despine(plt.gca())
plt.show()