#!/usr/bin/env python3
"""
SHIELD/ATLAS — Local Sensor Bridge
Connects hardware sensors (RTL-SDR + USB Microphone) to the SHIELD/ATLAS cloud platform.

Hardware Supported:
  - NooElec NESDR Smart v5 (RTL-SDR) — RF spectrum scanning
  - Fifine K669B (USB Microphone) — Acoustic detection
  - Any RTL2832U-based SDR dongle
  - Any USB microphone

Requirements:
  pip install pyrtlsdr numpy requests sounddevice

Usage:
  python sensor-bridge.py --server https://shield-atlas-production.up.railway.app
  python sensor-bridge.py --server https://shield-atlas-production.up.railway.app --mic-only
  python sensor-bridge.py --server https://shield-atlas-production.up.railway.app --sdr-only

Copyright 2026 Integrated Security Solutions (ISS) — SDVOSB
CAGE: 9VKK3 | UEI: C7YDV3P8EHL7
"""

import argparse
import json
import logging
import signal
import sys
import threading
import time
from datetime import datetime, timezone

import numpy as np
import requests

# Root logging configuration for the bridge: every record is rendered as
# "[timestamp] [LEVEL] message" and records below INFO are dropped.
logging.basicConfig(
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Single named logger shared by the sensor classes and main().
log = logging.getLogger("SHIELD-SENSOR-BRIDGE")

# Cooperative shutdown flag. The capture loops and main() poll it; the signal
# handler sets it to False to ask everything to stop.
RUNNING = True

def signal_handler(sig, frame):
    """Request a graceful shutdown when SIGINT or SIGTERM is delivered.

    Args:
        sig: Number of the signal being handled.
        frame: Stack frame that was interrupted (unused).
    """
    global RUNNING
    # Record the shutdown request before doing any I/O, so it is not lost if
    # logging raises inside the handler.
    RUNNING = False
    try:
        # Restore the default action for this signal: if the graceful
        # shutdown stalls, sending the signal again terminates the process.
        signal.signal(sig, signal.SIG_DFL)
    except (ValueError, OSError):
        # signal.signal() refuses to run outside the main thread or for an
        # invalid signal number; the shutdown request itself is already set.
        pass
    log.info("Shutdown signal received — stopping sensors...")


# Route Ctrl+C and termination requests through the graceful-shutdown handler.
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


class AcousticSensor:
    """Microphone sensor that posts spectral features to the acoustic classifier.

    Audio is captured through ``sounddevice``. At most once every
    ``CLASSIFY_INTERVAL_S`` seconds the most recent ``fft_size`` samples are
    summarised (peak frequency, RMS level, spectral centroid/spread, a coarse
    shape label) and POSTed to ``<server>/api/acoustic/classify``.
    """

    # Minimum number of seconds between two classification requests.
    CLASSIFY_INTERVAL_S = 2.0
    # Number of frames handed to the audio callback per invocation.
    BLOCK_SIZE = 1024

    def __init__(self, server_url, device_index=None, sample_rate=44100, fft_size=4096):
        """Store the configuration; no audio device is touched here.

        Args:
            server_url: Base URL of the server (a trailing "/" is removed).
            device_index: sounddevice input device index, or None for the
                default input device.
            sample_rate: Capture rate in Hz.
            fft_size: Number of samples analysed per classification.
        """
        self.server_url = server_url.rstrip("/")
        self.device_index = device_index
        self.sample_rate = sample_rate
        self.fft_size = fft_size
        self.classify_url = f"{self.server_url}/api/acoustic/classify"
        self.stream = None

    def start(self):
        """Open the input stream and begin periodic classification.

        Returns as soon as the stream is running. Samples are collected in the
        stream callback, and every classification request is sent from its own
        daemon thread so the callback never waits on the network.

        Raises:
            ImportError: If ``sounddevice`` is not installed.
            Exception: Whatever ``sounddevice`` raises when the device cannot
                be resolved or the stream cannot be opened.
        """
        import sounddevice as sd
        from collections import deque

        log.info(f"[ACOUSTIC] Initializing USB microphone (rate={self.sample_rate}, FFT={self.fft_size})")

        # kind="input" resolves None to the default input device and rejects
        # devices without input channels, instead of indexing the device list
        # with a possibly unset default index.
        dev = sd.query_devices(self.device_index, "input")
        if self.device_index is not None:
            log.info(f"[ACOUSTIC] Using device {self.device_index}: {dev['name']}")
        else:
            log.info(f"[ACOUSTIC] Using default input device: {dev['name']}")

        # Only the newest fft_size samples are ever analysed, so a bounded
        # deque is sufficient and keeps memory use constant.
        recent = deque(maxlen=self.fft_size)
        # Monotonic clock: the interval is unaffected by wall-clock changes.
        last_classify = time.monotonic()

        def audio_callback(indata, frames, time_info, status):
            nonlocal last_classify
            if status:
                log.warning(f"[ACOUSTIC] Stream status: {status}")
            # The stream is opened with one channel; column 0 is that channel.
            recent.extend(indata[:, 0].tolist())

            now = time.monotonic()
            if now - last_classify >= self.CLASSIFY_INTERVAL_S and len(recent) == self.fft_size:
                samples = np.array(recent, dtype=np.float64)
                last_classify = now

                threading.Thread(
                    target=self._classify,
                    args=(samples,),
                    daemon=True
                ).start()

        self.stream = sd.InputStream(
            device=self.device_index,
            channels=1,
            samplerate=self.sample_rate,
            blocksize=self.BLOCK_SIZE,
            callback=audio_callback
        )
        self.stream.start()

        # Announced only once the stream is actually running.
        log.info("[ACOUSTIC] Sensor ACTIVE — capturing audio and classifying every 2 seconds")

    @staticmethod
    def _spectral_shape(spectral_spread, peak_freq):
        """Map spectral spread (Hz) and peak frequency (Hz) to a shape label."""
        if spectral_spread < 200:
            return "IMPULSE"
        if spectral_spread > 2000:
            return "BROADBAND"
        if peak_freq > 5000:
            return "SUPERSONIC_CRACK"
        return "SUSTAINED_ROAR"

    def _build_payload(self, samples):
        """Compute the feature dictionary that is sent to the classifier.

        Args:
            samples: 1-D sequence of audio samples (at least two values).

        Returns:
            A JSON-serialisable dict with the peak frequency, RMS level in dB,
            spectral shape label, spectral centroid and spread, a UTC
            timestamp and the sensor identification fields.
        """
        samples = np.asarray(samples, dtype=np.float64)

        # Hann window before the FFT to limit spectral leakage.
        spectrum = np.abs(np.fft.rfft(samples * np.hanning(len(samples))))
        freqs = np.fft.rfftfreq(len(samples), 1.0 / self.sample_rate)

        # Bin 0 (DC) is excluded from the peak search and from the moments.
        peak_idx = int(np.argmax(spectrum[1:])) + 1
        peak_freq = float(freqs[peak_idx])

        rms = float(np.sqrt(np.mean(samples ** 2)))
        # Clamp so that digital silence yields -200 dB instead of -inf.
        rms_db = float(20 * np.log10(max(rms, 1e-10)))

        energy = spectrum[1:] ** 2
        total_energy = float(np.sum(energy))
        if total_energy > 0:
            centroid = float(np.sum(freqs[1:] * energy) / total_energy)
            spread = float(np.sqrt(np.sum(((freqs[1:] - centroid) ** 2) * energy) / total_energy))
        else:
            centroid = 0.0
            spread = 0.0

        return {
            "peakFrequencyHz": round(peak_freq, 1),
            "rmsDb": round(rms_db, 1),
            "spectralShape": self._spectral_shape(spread, peak_freq),
            "spectralCentroid": round(centroid, 1),
            "spectralSpread": round(spread, 1),
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "sensorType": "USB_MICROPHONE",
            "sensorModel": "LOCAL_BRIDGE"
        }

    def _classify(self, samples):
        """Extract features from ``samples``, POST them and log the verdict.

        Never raises: network, HTTP and processing errors are logged so that a
        failed request cannot take down the worker thread noisily.
        """
        try:
            payload = self._build_payload(samples)
            peak_freq = payload["peakFrequencyHz"]
            rms_db = payload["rmsDb"]
            spectral_shape = payload["spectralShape"]

            resp = requests.post(self.classify_url, json=payload, timeout=5)
            # An error page must not be interpreted as a classification result.
            resp.raise_for_status()
            result = resp.json()

            if result.get("matched"):
                # Tolerate an explicit null for topMatch / confidence.
                top = result.get("topMatch") or {}
                log.warning(
                    f"[ACOUSTIC] MATCH: {top.get('name', 'Unknown')} "
                    f"(confidence: {top.get('confidence') or 0:.0%}) "
                    f"peak={peak_freq:.0f}Hz rms={rms_db:.1f}dB shape={spectral_shape}"
                )
            else:
                log.info(
                    f"[ACOUSTIC] No match — peak={peak_freq:.0f}Hz "
                    f"rms={rms_db:.1f}dB shape={spectral_shape}"
                )

        except requests.exceptions.HTTPError as e:
            log.error(f"[ACOUSTIC] Server rejected classification request: {e}")
        except requests.exceptions.RequestException as e:
            log.error(f"[ACOUSTIC] Server connection failed: {e}")
        except Exception as e:
            log.error(f"[ACOUSTIC] Classification error: {e}")

    def stop(self):
        """Stop and close the input stream if one is open; safe to call twice."""
        if self.stream:
            self.stream.stop()
            self.stream.close()
            self.stream = None
            log.info("[ACOUSTIC] Microphone stream stopped")


class SDRSensor:
    """RTL-SDR sensor that detects spectral peaks and posts them for classification.

    Each capture is transformed with an FFT; bins that exceed the median level
    by ``DETECTION_MARGIN_DB`` are grouped into signals and every signal is
    POSTed to ``<server>/api/rf/classify``.

    Attributes:
        sdr: Open device handle, or None. ``start()`` opens a device only when
            this is None and otherwise reuses the handle that is already set.
    """

    # IQ samples requested per capture.
    SAMPLES_PER_READ = 256 * 1024
    # A bin counts as a detection when it exceeds the median level by this many dB.
    DETECTION_MARGIN_DB = 15
    # Detected bins at most this many indices apart belong to the same signal.
    MAX_BIN_GAP = 5
    # Upper bound on signals reported per capture (lowest-frequency groups first).
    MAX_SIGNALS_PER_READ = 10

    def __init__(self, server_url, center_freq=2.4e9, sample_rate=2.4e6, gain="auto"):
        """Store the configuration; the device is not opened here.

        Args:
            server_url: Base URL of the server (a trailing "/" is removed).
            center_freq: Tuner centre frequency in Hz.
            sample_rate: Sample rate in samples per second.
            gain: "auto" or a value convertible to float.
        """
        self.server_url = server_url.rstrip("/")
        self.center_freq = center_freq
        self.sample_rate = sample_rate
        self.gain = gain
        self.classify_url = f"{self.server_url}/api/rf/classify"
        self.sdr = None

    def start(self):
        """Configure the device and run the capture loop until RUNNING is False.

        Blocks the calling thread. A handle already stored in ``self.sdr`` is
        reused; a device is opened only when none is set.

        Raises:
            ImportError: If ``pyrtlsdr`` is not installed.
            Exception: Whatever the driver raises while opening or configuring
                the device. Errors inside the capture loop are logged and the
                loop retries.
        """
        from rtlsdr import RtlSdr

        log.info(f"[SDR] Initializing RTL-SDR (center={self.center_freq/1e6:.1f}MHz, rate={self.sample_rate/1e6:.1f}MS/s)")

        # Reuse a handle supplied by the caller. Opening the same dongle a
        # second time is expected to fail because it is already claimed, and
        # the first handle would be leaked.
        if self.sdr is None:
            self.sdr = RtlSdr()
        self.sdr.sample_rate = self.sample_rate
        self.sdr.center_freq = self.center_freq
        self.sdr.gain = "auto" if self.gain == "auto" else float(self.gain)

        log.info(f"[SDR] RTL-SDR ACTIVE — scanning {self.center_freq/1e6:.1f}MHz band")
        log.info(f"[SDR] Gain: {self.sdr.gain}, Sample rate: {self.sdr.sample_rate/1e6:.1f} MS/s")

        while RUNNING:
            try:
                samples = self.sdr.read_samples(self.SAMPLES_PER_READ)
                signals, noise_floor, threshold = self._detect_signals(samples)

                if signals:
                    for sig in signals:
                        self._classify_signal(sig)
                else:
                    log.debug(f"[SDR] No signals above threshold ({threshold:.1f}dB, floor={noise_floor:.1f}dB)")

                time.sleep(1)

            except Exception as e:
                # Keep the sensor alive across transient read failures.
                log.error(f"[SDR] Read error: {e}")
                time.sleep(2)

    def _detect_signals(self, samples):
        """Turn one capture into detected signals.

        Args:
            samples: Complex samples from one read.

        Returns:
            Tuple ``(signals, noise_floor, threshold)`` where ``signals`` is
            the list built by ``_extract_signals`` (empty when nothing exceeds
            the threshold) and the other two values are in dB.
        """
        # fftshift puts the bins in ascending frequency order around the centre.
        magnitude = np.abs(np.fft.fftshift(np.fft.fft(samples)))
        # The small offset avoids log10(0) for empty bins.
        fft_db = 20 * np.log10(magnitude + 1e-10)
        freqs = np.fft.fftshift(
            np.fft.fftfreq(len(samples), 1.0 / self.sample_rate)
        ) + self.center_freq

        # The median serves as the noise-floor estimate.
        noise_floor = np.median(fft_db)
        threshold = noise_floor + self.DETECTION_MARGIN_DB

        peak_indices = np.where(fft_db > threshold)[0]
        signals = self._extract_signals(freqs, fft_db, peak_indices, noise_floor)
        return signals, noise_floor, threshold

    def _extract_signals(self, freqs, fft_db, peak_indices, noise_floor):
        """Group above-threshold bins into signals and describe each one.

        Args:
            freqs: Frequency of every FFT bin in Hz.
            fft_db: Level of every FFT bin in dB.
            peak_indices: Ascending indices of the bins above the threshold.
            noise_floor: Noise-floor estimate in dB.

        Returns:
            A list of at most ``MAX_SIGNALS_PER_READ`` dicts, one per group in
            ascending frequency order; empty when ``peak_indices`` is empty.
        """
        peak_indices = np.asarray(peak_indices, dtype=int)
        if peak_indices.size == 0:
            return []

        # Start a new group wherever consecutive detections are more than
        # MAX_BIN_GAP bins apart.
        boundaries = np.flatnonzero(np.diff(peak_indices) > self.MAX_BIN_GAP) + 1
        groups = np.split(peak_indices, boundaries)

        signals = []
        for group in groups[:self.MAX_SIGNALS_PER_READ]:
            # The strongest bin of the group represents the signal.
            max_idx = group[np.argmax(fft_db[group])]
            signals.append({
                "frequencyHz": float(freqs[max_idx]),
                "frequencyMHz": float(freqs[max_idx] / 1e6),
                "powerDb": float(fft_db[max_idx]),
                "snrDb": float(fft_db[max_idx] - noise_floor),
                "bandwidthHz": float(freqs[group[-1]] - freqs[group[0]]),
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "sensorType": "RTL_SDR",
                "sensorModel": "LOCAL_BRIDGE"
            })

        return signals

    def _classify_signal(self, signal_data):
        """POST one detected signal to the classifier and log the outcome.

        Never raises: network, HTTP and processing errors are logged.

        Args:
            signal_data: Dict as produced by ``_extract_signals``; extra keys
                are sent to the server unchanged.
        """
        try:
            resp = requests.post(self.classify_url, json=signal_data, timeout=5)
            # An error page must not be interpreted as a classification result.
            resp.raise_for_status()
            result = resp.json()

            freq_mhz = signal_data["frequencyMHz"]
            power = signal_data["powerDb"]
            snr = signal_data["snrDb"]

            classification = result.get("classification")
            if classification and classification != "UNKNOWN":
                log.warning(
                    f"[SDR] SIGNAL CLASSIFIED: {classification} "
                    f"at {freq_mhz:.3f}MHz (power={power:.1f}dB, SNR={snr:.1f}dB)"
                )
            else:
                log.info(
                    f"[SDR] Signal detected: {freq_mhz:.3f}MHz "
                    f"(power={power:.1f}dB, SNR={snr:.1f}dB) — unclassified"
                )

        except requests.exceptions.HTTPError as e:
            log.error(f"[SDR] Server rejected classification request: {e}")
        except requests.exceptions.RequestException as e:
            log.error(f"[SDR] Server connection failed: {e}")
        except Exception as e:
            log.error(f"[SDR] Classification error: {e}")

    def stop(self):
        """Close the device if one is open; safe to call more than once."""
        if self.sdr:
            self.sdr.close()
            self.sdr = None
            log.info("[SDR] RTL-SDR closed")


class FrequencyScanner:
    """Steps an already-open SDR through a fixed list of bands.

    For every band the tuner is retuned, captures are taken for ``dwell_time``
    seconds, and each detected signal is tagged with the band name and passed
    to the SDR sensor's classifier call.
    """

    # NOTE(review): 2.4 GHz is above the tuning range usually quoted for
    # RTL2832U/R820T2 dongles — confirm the attached hardware can tune it.
    BANDS = [
        {"name": "433MHz ISM", "center": 433.92e6, "desc": "IoT, remotes, some drones"},
        {"name": "868MHz EU ISM", "center": 868e6, "desc": "LoRa, EU IoT"},
        {"name": "915MHz US ISM", "center": 915e6, "desc": "LoRa, US IoT"},
        {"name": "1090MHz ADS-B", "center": 1090e6, "desc": "Aircraft transponders"},
        {"name": "2.4GHz ISM", "center": 2.4e9, "desc": "WiFi, Bluetooth, most consumer drones"},
    ]

    def __init__(self, server_url, dwell_time=5):
        """
        Args:
            server_url: Base URL of the server (a trailing "/" is removed, as
                the sensor classes do).
            dwell_time: Seconds spent capturing on each band.
        """
        self.server_url = server_url.rstrip("/")
        self.dwell_time = dwell_time

    def scan_all_bands(self, sdr_sensor):
        """Run one sweep over ``BANDS`` using ``sdr_sensor``'s open device.

        The sweep stops early when RUNNING becomes False. A tuning failure
        aborts the sweep; a capture failure abandons only the current band.

        Args:
            sdr_sensor: Object exposing ``sdr`` (open device handle),
                ``center_freq``, ``sample_rate``, ``_extract_signals`` and
                ``_classify_signal``.
        """
        if sdr_sensor.sdr is None:
            log.error("[SCANNER] SDR device is not open — sweep skipped")
            return

        log.info(f"[SCANNER] Starting frequency sweep — {len(self.BANDS)} bands, {self.dwell_time}s dwell")
        for band in self.BANDS:
            if not RUNNING:
                break
            log.info(f"[SCANNER] Scanning {band['name']} ({band['center']/1e6:.1f}MHz) — {band['desc']}")
            try:
                sdr_sensor.sdr.center_freq = band["center"]
            except Exception as e:
                # The device state after a failed tune is unknown, so stop
                # instead of continuing to drive it.
                log.error(f"[SCANNER] Failed to tune to {band['name']}: {e} — sweep aborted")
                return
            # Keep the sensor's notion of the centre frequency in step with the tuner.
            sdr_sensor.center_freq = band["center"]
            # Short pause after retuning (presumably lets the tuner settle).
            time.sleep(0.1)

            self._dwell_on_band(sdr_sensor, band)

        log.info("[SCANNER] Frequency sweep complete")

    def _dwell_on_band(self, sdr_sensor, band):
        """Capture and classify on the currently tuned band for ``dwell_time`` seconds."""
        # Monotonic clock: the dwell is unaffected by wall-clock changes.
        deadline = time.monotonic() + self.dwell_time
        while RUNNING and time.monotonic() < deadline:
            try:
                samples = sdr_sensor.sdr.read_samples(256 * 1024)
                magnitude = np.abs(np.fft.fftshift(np.fft.fft(samples)))
                # The small offset avoids log10(0) for empty bins.
                fft_db = 20 * np.log10(magnitude + 1e-10)
                freqs = np.fft.fftshift(
                    np.fft.fftfreq(len(samples), 1.0 / sdr_sensor.sample_rate)
                ) + band["center"]

                # Median level as noise floor; detections are bins 15 dB above it.
                noise_floor = np.median(fft_db)
                threshold = noise_floor + 15
                peak_indices = np.where(fft_db > threshold)[0]

                if len(peak_indices) > 0:
                    signals = sdr_sensor._extract_signals(freqs, fft_db, peak_indices, noise_floor)
                    for sig in signals:
                        sig["band"] = band["name"]
                        sdr_sensor._classify_signal(sig)

                time.sleep(1)
            except Exception as e:
                # Give up on this band only; the caller moves on to the next one.
                log.error(f"[SCANNER] Error on {band['name']}: {e}")
                break


def main():
    """Parse the command line, start the requested sensors and run until stopped.

    ``--mic-only`` and ``--sdr-only`` are mutually exclusive. The process
    exits with status 1 when the server cannot be reached, when the sensor
    required by ``--mic-only``/``--sdr-only`` fails to start, or when no
    sensor at all could be started.
    """
    parser = argparse.ArgumentParser(
        description="SHIELD/ATLAS Local Sensor Bridge — Connect hardware to the cloud platform"
    )
    parser.add_argument(
        "--server", required=True,
        help="SHIELD/ATLAS server URL (e.g., https://shield-atlas-production.up.railway.app)"
    )
    # Passing both flags would disable every sensor, so argparse rejects it.
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument("--mic-only", action="store_true", help="Run acoustic sensor only (no SDR)")
    mode_group.add_argument("--sdr-only", action="store_true", help="Run SDR sensor only (no mic)")
    parser.add_argument("--sdr-freq", type=float, default=2.4e9, help="SDR center frequency in Hz (default: 2.4GHz)")
    parser.add_argument("--sdr-gain", default="auto", help="SDR gain (default: auto)")
    parser.add_argument("--sdr-rate", type=float, default=2.4e6, help="SDR sample rate (default: 2.4MS/s)")
    parser.add_argument("--mic-device", type=int, default=None, help="Microphone device index")
    parser.add_argument("--scan", action="store_true", help="Enable frequency sweep across common drone bands")
    parser.add_argument("--scan-dwell", type=int, default=5, help="Seconds to dwell on each frequency band")

    args = parser.parse_args()

    print()
    print("=" * 60)
    print("  SHIELD/ATLAS — Local Sensor Bridge v1.0.0")
    print("  Integrated Security Solutions (ISS) — SDVOSB")
    print("=" * 60)
    print(f"  Server:  {args.server}")
    print(f"  Mode:    {'MIC ONLY' if args.mic_only else 'SDR ONLY' if args.sdr_only else 'FULL (MIC + SDR)'}")
    if not args.mic_only:
        print(f"  SDR:     {args.sdr_freq/1e6:.1f}MHz, {args.sdr_rate/1e6:.1f}MS/s, gain={args.sdr_gain}")
    print("=" * 60)
    print()

    # Fail fast if the server is unreachable; an unexpected HTTP status is
    # only a warning.
    try:
        resp = requests.get(f"{args.server.rstrip('/')}/api/cram/status", timeout=10)
    except Exception as e:
        log.error(f"[BRIDGE] Cannot reach server at {args.server}: {e}")
        log.error("[BRIDGE] Check server URL and try again")
        sys.exit(1)
    if resp.status_code == 200:
        log.info(f"[BRIDGE] Connected to SHIELD/ATLAS at {args.server}")
    else:
        log.warning(f"[BRIDGE] Server responded with {resp.status_code} — continuing anyway")

    acoustic = None
    acoustic_active = False
    sdr = None
    sdr_thread = None
    sdr_active = False

    if not args.sdr_only:
        try:
            acoustic = AcousticSensor(
                server_url=args.server,
                device_index=args.mic_device,
            )
            acoustic.start()
            acoustic_active = True
        except ImportError:
            log.error("[ACOUSTIC] sounddevice not installed — run: pip install sounddevice")
            if args.mic_only:
                sys.exit(1)
        except Exception as e:
            log.error(f"[ACOUSTIC] Failed to start microphone: {e}")
            if args.mic_only:
                sys.exit(1)

    if not args.mic_only:
        try:
            from rtlsdr import RtlSdr

            sdr = SDRSensor(
                server_url=args.server,
                center_freq=args.sdr_freq,
                sample_rate=args.sdr_rate,
                gain=args.sdr_gain,
            )

            # Open and configure the device here so that a missing dongle or a
            # bad setting is reported before any worker thread is started.
            sdr.sdr = RtlSdr()
            sdr.sdr.sample_rate = sdr.sample_rate
            sdr.sdr.center_freq = sdr.center_freq
            sdr.sdr.gain = "auto" if sdr.gain == "auto" else float(sdr.gain)
            log.info("[SDR] RTL-SDR INITIALIZED — device ready")

            if args.scan:
                scanner = FrequencyScanner(args.server, dwell_time=args.scan_dwell)

                def sweep_until_stopped():
                    # One sweep covers every band once; repeat it so the SDR
                    # keeps monitoring for as long as the bridge runs.
                    while RUNNING:
                        scanner.scan_all_bands(sdr)
                        # Brief pause so a sweep that fails immediately cannot
                        # turn this into a busy loop.
                        time.sleep(1)

                sdr_thread = threading.Thread(target=sweep_until_stopped, daemon=True)
            else:
                sdr_thread = threading.Thread(target=sdr.start, daemon=True)
            sdr_thread.start()
            sdr_active = True

        except ImportError:
            log.error("[SDR] pyrtlsdr not installed — run: pip install pyrtlsdr")
            if args.sdr_only:
                sys.exit(1)
        except Exception as e:
            log.error(f"[SDR] Failed to start RTL-SDR: {e}")
            log.info("[SDR] Is the SDR dongle plugged in? Check with: rtl_test")
            if args.sdr_only:
                sys.exit(1)

    active = [label for label, ok in (("MIC", acoustic_active), ("SDR", sdr_active)) if ok]
    if not active:
        # Nothing to run: release whatever was opened and report the failure
        # instead of idling forever.
        log.error("[BRIDGE] No sensors could be started — exiting")
        if acoustic:
            acoustic.stop()
        if sdr:
            sdr.stop()
        sys.exit(1)

    log.info(f"[BRIDGE] Sensors active: {', '.join(active)} — press Ctrl+C to stop")

    while RUNNING:
        time.sleep(1)

    log.info("[BRIDGE] Shutting down...")
    if acoustic:
        acoustic.stop()
    if sdr:
        if sdr_thread is not None:
            # Give the capture thread a moment to notice RUNNING is False
            # before the device it is reading from is closed.
            sdr_thread.join(timeout=5)
        sdr.stop()
    log.info("[BRIDGE] Sensor bridge stopped")


if __name__ == "__main__":
    main()
