"""Plotting helpers for reward curves: static plots and a live-updating plot."""

import os
import time

import matplotlib.pyplot as plt
import numpy as np


def smooth_curve(points, factor=0.9):
    """Return an exponentially smoothed copy of ``points``.

    The first output equals the first input; every later output is
    ``previous_smoothed * factor + current * (1 - factor)``.

    Args:
        points: Iterable of numbers (list, tuple, NumPy array, ...).
            ``None`` or an empty iterable yields an empty list.
        factor: Weight given to the previous smoothed value.

    Returns:
        A new list with one smoothed value per input point.
    """
    if points is None:
        return []

    smoothed_points = []
    for point in points:
        if smoothed_points:
            smoothed_points.append(
                smoothed_points[-1] * factor + point * (1 - factor)
            )
        else:
            # The first point seeds the running average unchanged.
            smoothed_points.append(point)
    return smoothed_points


def plot_rewards(rewards_history, log_interval, save_dir,
                 filename="Trained_Agent_rewards_plot.png", show_plot=True):
    """Plot ``rewards_history`` against episode number and save it as an image.

    The i-th entry (1-based) of ``rewards_history`` is drawn at
    ``x = i * log_interval``.

    Args:
        rewards_history: Sequence of reward values, one per logging step.
        log_interval: Multiplier turning a 1-based index into an x value.
        save_dir: Directory for the image; created if it does not exist.
        filename: File name of the image inside ``save_dir``.
        show_plot: If true, call ``plt.show()`` after saving. If false, the
            figure is closed after saving.
    """
    os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    episodes = [i * log_interval for i in range(1, len(rewards_history) + 1)]
    ax.plot(episodes, rewards_history, label='Average Reward')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Average Reward')
    ax.set_title('Trained Agent Live Reward (Average Reward per Episode)')
    ax.grid(True)
    ax.legend()
    fig.tight_layout()

    save_path = os.path.join(save_dir, filename)
    # Save through the figure object so the right figure is always written.
    fig.savefig(save_path)
    print(f"Plot saved to: {os.path.abspath(save_path)}")

    if show_plot:
        plt.show()
    else:
        # Nothing will display this figure; close it so repeated calls do
        # not accumulate open figures in pyplot's registry.
        plt.close(fig)


def init_live_plot(save_dir, filename="live_rewards_plot.png"):
    """Create the figure, axes and empty line used for live reward plotting.

    Turns on matplotlib's interactive mode. The path
    ``os.path.join(save_dir, filename)`` is stored on the axes as the private
    attribute ``_save_path_final``, which ``save_live_plot_final`` reads.

    Args:
        save_dir: Directory the final image should be written to.
        filename: File name of the final image.

    Returns:
        A ``(fig, ax, line)`` tuple.
    """
    plt.ion()
    fig, ax = plt.subplots(figsize=(12, 6))
    line, = ax.plot([], [], label='Smoothed Average Reward')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Average Reward')
    ax.set_title('Live Reward for Trained Agent')
    ax.grid(True)
    ax.legend()
    plt.tight_layout()
    ax._save_path_final = os.path.join(save_dir, filename)
    return fig, ax, line


def _is_empty(values):
    """Return True for ``None`` or a zero-length sequence (NumPy-safe)."""
    return values is None or len(values) == 0


def _padded_ylim(values):
    """Return ``(low, high)`` y-limits enclosing ``values`` with 10% padding.

    Each bound is pushed away from the data by 10% of its own magnitude, so
    negative values are padded outward too. A span narrower than 0.1 is
    widened by 0.05 on each side.
    """
    lowest = min(values)
    highest = max(values)
    # Scaling a negative minimum by 0.9 (or a negative maximum by 1.1) would
    # move the limit *into* the data, so the factor depends on the sign.
    min_y = lowest * (0.9 if lowest >= 0 else 1.1)
    max_y = highest * (1.1 if highest >= 0 else 0.9)
    if abs(max_y - min_y) < 0.1:
        min_y -= 0.05
        max_y += 0.05
    return min_y, max_y


def update_live_plot(fig, ax, line, episodes, smoothed_rewards,
                     current_timestep=None, total_timesteps=None):
    """Redraw the live plot with the latest data.

    Does nothing when either data sequence is ``None`` or empty.

    Args:
        fig: Figure returned by ``init_live_plot``.
        ax: Axes returned by ``init_live_plot``.
        line: Line artist returned by ``init_live_plot``.
        episodes: X values for the line.
        smoothed_rewards: Y values for the line.
        current_timestep: Optional progress counter shown in the title.
        total_timesteps: Optional total shown in the title; the title is only
            updated when both timestep arguments are given.
    """
    if _is_empty(episodes) or _is_empty(smoothed_rewards):
        return

    line.set_data(episodes, smoothed_rewards)

    x_max = max(episodes) * 1.05
    if x_max <= 0:
        # Avoid a zero-width or inverted x-axis.
        x_max = 1
    ax.set_xlim(0, x_max)

    min_y, max_y = _padded_ylim(smoothed_rewards)
    ax.set_ylim(min_y, max_y)

    if current_timestep is not None and total_timesteps is not None:
        ax.set_title(
            f'Live Agent Progress (Timestep: {current_timestep:,}/{total_timesteps:,})'
        )

    fig.canvas.draw()
    fig.canvas.flush_events()
    # Short pause between redraws, kept from the original implementation.
    time.sleep(0.01)


def save_live_plot_final(fig, ax):
    """Save the live plot to the path recorded on ``ax`` and close the figure.

    Turns interactive mode off. If ``ax`` has a non-empty ``_save_path_final``
    attribute (set by ``init_live_plot``), the figure is written there and the
    parent directory is created when missing.

    Args:
        fig: Figure returned by ``init_live_plot``.
        ax: Axes returned by ``init_live_plot``.
    """
    plt.ioff()
    save_path = getattr(ax, '_save_path_final', None)
    if save_path:
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Save via ``fig``: ``plt.savefig`` writes whichever figure is
        # current, which may not be the live plot.
        fig.savefig(save_path)
        print(f"Final live plot saved to: {os.path.abspath(save_path)}")
    plt.close(fig)
    # ``fig`` is already closed, so this only affects other open figures.
    plt.show()