import matplotlib.pyplot as plt
import numpy as np
from MCEq.geometry.geometry import *

# Geometry object shared by every figure below.
earth = EarthGeometry()

# Detector zenith angles from vertical (0 deg) to horizontal (90 deg), 500
# samples, plus the same grid in radians for the geometry methods.
theta_list = np.linspace(0, 90, 500)
th_list_rad = np.radians(theta_list)

# Heights from 0 up to earth.h_atm, 500 samples. The figures divide lengths by
# 1e5 and label the result "km", i.e. the geometry works in cm.
h_vec = np.linspace(0, earth.h_atm, 500)
# Figure 1: total path length l(theta) vs. detector zenith angle.
fig = plt.figure(figsize=(5, 4))
fig.set_tight_layout({"rect": [0.0, 0.0, 1, 1]})
ax = fig.gca()
# Lengths are divided by 1e5 to convert cm -> km for display.
ax.plot(theta_list, earth.path_len(th_list_rad) / 1e5, lw=2)
ax.set_xlabel(r"zenith $\theta$ at detector")
ax.set_ylabel(r"path length $l(\theta)$ in km")
# Open frame: hide the right/top spines and keep ticks on bottom/left only.
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.xaxis.set_ticks_position("bottom")
ax.yaxis.set_ticks_position("left")

# Figure 2: angle theta* at the top of the atmosphere vs. detector zenith.
fig = plt.figure(figsize=(5, 4))
fig.set_tight_layout({"rect": [0.0, 0.0, 1, 1]})
ax = fig.gca()
# earth.cos_th_star returns a cosine; arccos gives radians, then convert to deg.
ax.plot(
    theta_list,
    np.arccos(earth.cos_th_star(th_list_rad)) / np.pi * 180.0,
    lw=2,
)
ax.set_xlabel(r"zenith $\theta$ at detector")
ax.set_ylabel(r"$\theta^*$ at top of the atm.")
ax.set_ylim([0, 90])
# Open frame: hide the right/top spines and keep ticks on bottom/left only.
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.xaxis.set_ticks_position("bottom")
ax.yaxis.set_ticks_position("left")

# Figure 3: path length Delta l(h) vs. atmospheric height h, evaluated for a
# single fixed detector zenith angle. The angle is named once here and shown
# in the legend so the figure states which direction it describes.
theta_delta_l_deg = 85.0
fig = plt.figure(figsize=(5, 4))
fig.set_tight_layout(dict(rect=[0.00, 0.00, 1, 1]))
# Both axes are divided by 1e5 to convert cm -> km for display.
plt.plot(
    h_vec / 1e5,
    earth.delta_l(h_vec, np.deg2rad(theta_delta_l_deg)) / 1e5,
    lw=2,
    label=rf"$\theta = {theta_delta_l_deg:.0f}^\circ$",
)
plt.legend()
plt.ylabel(r"Path length $\Delta l(h)$ in km")
plt.xlabel(r"atm. height $h_{atm}$ in km")
ax = plt.gca()
# Open frame: hide the right/top spines and keep ticks on bottom/left only.
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.xaxis.set_ticks_position("bottom")
ax.yaxis.set_ticks_position("left")

# Figure 4: atmospheric height h(Delta l, theta) vs. path length Delta l,
# one curve per detector zenith angle (degrees), then show all figures.
fig = plt.figure(figsize=(5, 4))
fig.set_tight_layout(dict(rect=[0.00, 0.00, 1, 1]))
for theta in [30.0, 60.0, 70.0, 80.0, 85.0, 90.0]:
    theta_path = np.deg2rad(theta)
    # Sample from 0 up to the full path length for this zenith angle.
    delta_l_vec = np.linspace(0, earth.path_len(theta_path), 1000)
    plt.plot(
        delta_l_vec / 1e5,  # cm -> km
        earth.h(delta_l_vec, theta_path) / 1e5,  # cm -> km
        # Render e.g. "30°": integer formatting and a real degree symbol
        # (mathtext \circ) rather than "30.0" with a superscript letter o.
        label=rf"${theta:.0f}^\circ$",
        lw=2,
    )
# Legend title identifies the entries as zenith angles.
plt.legend(title=r"zenith $\theta$")
plt.xlabel(r"path length $\Delta l$ [km]")
plt.ylabel(r"atm. height $h_{atm}(\Delta l, \theta)$ [km]")
ax = plt.gca()
# Open frame: hide the right/top spines and keep ticks on bottom/left only.
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.xaxis.set_ticks_position("bottom")
ax.yaxis.set_ticks_position("left")
plt.show()