#ifndef PSK_MODULATOR_H
#define PSK_MODULATOR_H

#include <vector>
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <complex>
#include <algorithm>

/// 8-PSK modulator producing a real-valued 16-bit audio passband waveform.
///
/// Symbols 0..7 are mapped to unit phasors at k*45 degrees. They are pulse-shaped
/// at baseband with a square-root raised-cosine (SRRC) FIR filter, then
/// up-converted onto a fixed 1800 Hz carrier at a fixed 2400 symbols/s.
/// The carrier phase is stored in the object and continues across calls to
/// modulate(). An instance is therefore stateful, and modulate() is not safe to
/// call concurrently on the same object without external synchronization.
class PSKModulator {
public:
    /// @brief Construct a modulator for the given output sample rate.
    /// @param baud_rate            Ignored; the symbol rate is fixed at 2400 symbols/s.
    /// @param sample_rate          Output sample rate in Hz; must exceed twice the 1800 Hz carrier.
    /// @param energy_per_bit       Ignored; modulate() peak-normalises its output instead.
    /// @param is_frequency_hopping Ignored.
    /// @throws std::invalid_argument if sample_rate is not greater than 3600 Hz (or is NaN).
    PSKModulator([[maybe_unused]] double baud_rate, double sample_rate,
                 [[maybe_unused]] double energy_per_bit, [[maybe_unused]] bool is_frequency_hopping)
        : sample_rate(sample_rate),
          carrier_freq(kCarrierFrequency),
          phase(0.0),
          symbol_rate(kSymbolRate),
          samples_per_symbol(0) {
        // The carrier itself must be below Nyquist, otherwise it aliases.
        // (The full occupied band, carrier +/- 1620 Hz for rolloff 0.35, ideally
        // needs sample_rate above roughly 6840 Hz.)
        if (!(sample_rate > 2.0 * carrier_freq)) {
            throw std::invalid_argument(
                "PSKModulator: sample_rate must exceed twice the 1800 Hz carrier frequency");
        }
        // NOTE: when sample_rate / 2400 is not an integer, truncation makes the
        // effective symbol rate sample_rate / samples_per_symbol (slightly above 2400).
        samples_per_symbol = static_cast<size_t>(sample_rate / symbol_rate);
        initializeSymbolMap();
    }

    /// @brief Modulate a sequence of 8-PSK symbols into 16-bit audio samples.
    ///
    /// The result holds symbols.size() * samples_per_symbol samples, plus the
    /// (kFilterTaps - 1)-sample tail of the pulse-shaping filter, so that the final
    /// symbols are emitted in full. Symbol k's pulse peaks at sample
    /// k * samples_per_symbol + (kFilterTaps - 1) / 2. Each call scales its own
    /// output so that the largest magnitude is 32767.
    /// @param symbols Symbol indices in [0, 7].
    /// @return Real passband samples; empty if symbols is empty.
    /// @throws std::out_of_range if any symbol exceeds 7. Nothing is generated and
    ///         the carrier phase is unchanged in that case.
    std::vector<int16_t> modulate(const std::vector<uint8_t>& symbols) {
        // Validate everything first so a bad symbol cannot leave the carrier phase half-advanced.
        for (const auto symbol : symbols) {
            if (symbol >= symbolMap.size()) {
                throw std::out_of_range("Invalid symbol value for 8-PSK modulation");
            }
        }
        if (symbols.empty()) {
            return {};
        }

        const std::vector<double> filter_taps = sqrtRaisedCosineFilter(kFilterTaps, symbol_rate);

        // Baseband impulse train: one complex symbol every samples_per_symbol samples,
        // followed by (taps - 1) zeros so the causal filter flushes the last pulses.
        std::vector<std::complex<double>> baseband(
            symbols.size() * samples_per_symbol + filter_taps.size() - 1);
        for (size_t k = 0; k < symbols.size(); ++k) {
            baseband[k * samples_per_symbol] = symbolMap[symbols[k]];
        }
        const std::vector<std::complex<double>> shaped = applyFilter(baseband, filter_taps);

        // Up-convert to a real passband signal: Re{s(n) * e^{j*phase}} = I*cos(phase) - Q*sin(phase).
        const double phase_increment = 2.0 * kPi * carrier_freq / sample_rate;
        std::vector<double> passband;
        passband.reserve(shaped.size());
        double peak = 0.0;
        for (const auto& s : shaped) {
            const double value = s.real() * std::cos(phase) - s.imag() * std::sin(phase);
            passband.push_back(value);
            peak = std::max(peak, std::abs(value));
            phase = std::fmod(phase + phase_increment, 2.0 * kPi);
        }

        // Peak-normalise to full int16 scale and round (rather than truncate) each sample.
        const double gain = (peak > 0.0) ? (32767.0 / peak) : 1.0;
        std::vector<int16_t> audio;
        audio.reserve(passband.size());
        for (const double value : passband) {
            audio.push_back(static_cast<int16_t>(
                std::lround(std::clamp(gain * value, -32768.0, 32767.0))));
        }
        return audio;
    }

    /// @brief Design a unit-energy square-root raised-cosine FIR filter.
    ///
    /// Taps are sampled at this object's sample_rate and centred on index
    /// (num_taps - 1) / 2. The removable singularities at t = 0 and
    /// t = +/-T/(4*rolloff) are replaced by their analytic limits.
    /// @param num_taps    Number of taps.
    /// @param symbol_rate Symbol rate in symbols/s (T = 1 / symbol_rate).
    /// @param rolloff     Excess-bandwidth factor in [0, 1]; defaults to 0.35.
    /// @return Taps scaled so their squared sum is 1; empty when num_taps == 0.
    /// @throws std::invalid_argument if symbol_rate <= 0 or rolloff is outside [0, 1].
    std::vector<double> sqrtRaisedCosineFilter(size_t num_taps, double symbol_rate, double rolloff = 0.35) {
        if (!(symbol_rate > 0.0)) {
            throw std::invalid_argument("sqrtRaisedCosineFilter: symbol_rate must be positive");
        }
        if (!(rolloff >= 0.0 && rolloff <= 1.0)) {
            throw std::invalid_argument("sqrtRaisedCosineFilter: rolloff must be in [0, 1]");
        }

        std::vector<double> filter_taps(num_taps);
        if (num_taps == 0) {
            return filter_taps;
        }

        const double symbols_per_sample = symbol_rate / sample_rate;
        const double half_num_taps = static_cast<double>(num_taps - 1) / 2.0;
        double energy = 0.0;

        // The common 1/T scale factor is omitted from every branch; it cancels in the
        // final unit-energy normalisation.
        for (size_t i = 0; i < num_taps; ++i) {
            // x = t / T: offset from the filter centre, measured in symbol periods.
            const double x = (static_cast<double>(i) - half_num_taps) * symbols_per_sample;
            const double four_beta_x = 4.0 * rolloff * x;
            const double singular_factor = 1.0 - four_beta_x * four_beta_x;

            double tap = 0.0;
            if (std::abs(x) < 1e-12) {
                // Limit at t = 0.
                tap = 1.0 - rolloff + 4.0 * rolloff / kPi;
            } else if (std::abs(singular_factor) < 1e-9) {
                // Limit at t = +/-T/(4*rolloff), where numerator and denominator both vanish.
                const double a = kPi / (4.0 * rolloff);
                tap = rolloff / std::sqrt(2.0) *
                      ((1.0 + 2.0 / kPi) * std::sin(a) + (1.0 - 2.0 / kPi) * std::cos(a));
            } else {
                const double numerator = std::sin(kPi * x * (1.0 - rolloff)) +
                                         four_beta_x * std::cos(kPi * x * (1.0 + rolloff));
                tap = numerator / (kPi * x * singular_factor);
            }
            filter_taps[i] = tap;
            energy += tap * tap;
        }

        const double norm = std::sqrt(energy);
        if (norm > 0.0) {
            for (double& tap : filter_taps) {
                tap /= norm;
            }
        }
        return filter_taps;
    }

    /// @brief Causal FIR convolution of a complex signal with real taps.
    ///
    /// output[i] = sum_{j <= i} filter_taps[j] * signal[i - j]. Samples before the
    /// start of the signal are treated as zero. The output has the same length as
    /// the input, so the filter's trailing transient is not included.
    /// @param signal      Input samples.
    /// @param filter_taps FIR coefficients, with filter_taps[0] applied to the newest sample.
    /// @return Filtered samples, same length as signal.
    std::vector<std::complex<double>> applyFilter(const std::vector<std::complex<double>>& signal, const std::vector<double>& filter_taps) {
        std::vector<std::complex<double>> filtered_signal(signal.size());
        const size_t filter_length = filter_taps.size();

        for (size_t i = 0; i < signal.size(); ++i) {
            // Only taps that overlap real input samples contribute (implicit zero padding).
            const size_t overlap = std::min(i + 1, filter_length);
            std::complex<double> acc(0.0, 0.0);
            for (size_t j = 0; j < overlap; ++j) {
                acc += filter_taps[j] * signal[i - j];
            }
            filtered_signal[i] = acc;
        }
        return filtered_signal;
    }

private:
    static constexpr double kPi = 3.14159265358979323846;  ///< Portable replacement for non-standard M_PI.
    static constexpr double kCarrierFrequency = 1800.0;    ///< Carrier frequency in Hz.
    static constexpr double kSymbolRate = 2400.0;          ///< Fixed symbol rate in symbols/s.
    static constexpr size_t kFilterTaps = 201;             ///< SRRC length used by modulate().

    double sample_rate;        ///< Output sample rate in Hz.
    double carrier_freq;       ///< Carrier frequency in Hz (always kCarrierFrequency).
    double phase;              ///< Carrier phase in radians, carried across modulate() calls.
    double symbol_rate;        ///< Symbol rate in symbols/s (always kSymbolRate).
    size_t samples_per_symbol; ///< floor(sample_rate / symbol_rate); at least 1 after construction.
    std::vector<std::complex<double>> symbolMap;  ///< Symbol index k -> unit I/Q phasor at k*45 degrees.

    /// Fill symbolMap with the eight 8-PSK constellation points, counter-clockwise from 0 degrees.
    void initializeSymbolMap() {
        const double r = std::sqrt(2.0) / 2.0;  // cos(45 deg) == sin(45 deg)
        symbolMap = {
            {1.0, 0.0},   // 0 (000):   0 deg
            {r, r},       // 1 (001):  45 deg
            {0.0, 1.0},   // 2 (010):  90 deg
            {-r, r},      // 3 (011): 135 deg
            {-1.0, 0.0},  // 4 (100): 180 deg
            {-r, -r},     // 5 (101): 225 deg
            {0.0, -1.0},  // 6 (110): 270 deg
            {r, -r}       // 7 (111): 315 deg
        };
    }
};

#endif