//! Jump table implementation.
//!
//! This module defines the `Jump` enum and re-exports the `JumpTable` struct (from the `table`
//! submodule). Together they record the jumps in the program — jumps to labels, calls to
//! functions, and calls to external functions — and relocate them.

use crate::codegen::ExtFunc;
use core::fmt::Display;
pub use table::JumpTable;

mod pc;
mod relocate;
mod table;
mod target;

/// Represents the different types of jumps in the program.
///
/// A `Jump` describes where a jump instruction registered in a [`JumpTable`]
/// should go. See [`Jump::is_call`] for how the variants are classified.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Jump {
    /// Jump to a specific label, which corresponds to the original program counter.
    ///
    /// The `u16` is the label's program counter in the original (pre-relocation) code.
    Label(u16),
    /// Jump to a function identified by its index.
    Func(u32),
    /// Jump to an external function.
    ExtFunc(ExtFunc),
}

impl Display for Jump {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Jump::Label(offset) => write!(f, "Label(0x{offset:x})"),
            Jump::Func(index) => write!(f, "Func({index})"),
            Jump::ExtFunc(_) => write!(f, "ExtFunc"),
        }
    }
}

impl Jump {
    /// Checks if the target is a label.
    pub fn is_label(&self) -> bool {
        matches!(self, Jump::Label { .. })
    }

    /// Checks if the target is a function call.
    pub fn is_call(&self) -> bool {
        !self.is_label()
    }
}

#[cfg(test)]
mod tests {
    // Tests for `JumpTable` relocation. Expected values are computed by hand:
    // a jump emitted as PUSH1 (opcode 0x60) occupies 2 bytes (opcode + 1-byte
    // target), and one emitted as PUSH2 (opcode 0x61) occupies 3 bytes
    // (opcode + 2-byte target). The expectations below assume every such
    // insertion shifts the code and any targets located after it.
    use crate::jump::{Jump, JumpTable};
    use smallvec::smallvec;

    /// Installs a compact, timestamp-free `tracing` subscriber whose filter is
    /// read from the environment (`EnvFilter::from_default_env`, i.e. `RUST_LOG`).
    ///
    /// Errors from `try_init` (e.g. a subscriber already installed by another
    /// test) are deliberately discarded with `.ok()`, so repeated calls are harmless.
    // `unused` is allowed so the helper can stay here even when no test calls it.
    #[allow(unused)]
    fn init_tracing() {
        tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .without_time()
            .compact()
            .try_init()
            .ok();
    }

    /// Asserts that `shift_targets` and `relocate` agree on the final layout.
    ///
    /// A zero-filled buffer is sized from the maximum target *before* shifting.
    /// After `shift_targets` and `relocate`, its length must equal the maximum
    /// target reported *after* shifting.
    fn assert_target_shift_vs_relocation(mut table: JumpTable) -> anyhow::Result<()> {
        // Zero-filled buffer sized by the maximum (not yet shifted) target
        let mut buffer = smallvec![0; table.max_target() as usize];

        // Perform target shifts
        table.shift_targets()?;

        // Find the maximum target after shifts
        let max_target = table.max_target();

        // Perform relocation
        table.relocate(&mut buffer)?;

        // The relocated buffer must end exactly at the shifted maximum target
        assert_eq!(buffer.len(), max_target as usize);
        Ok(())
    }

    /// Two independent label jumps, each located before its target.
    #[test]
    fn test_target_shift_vs_relocation() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Register jumps with known offsets and labels
        table.register(0x10, Jump::Label(0x20)); // Jump to label at 0x20
        table.register(0x30, Jump::Label(0x40)); // Jump to label at 0x40

        assert_target_shift_vs_relocation(table)
    }

    /// Two jump sites sharing the same label target (0x100).
    #[test]
    fn test_multiple_internal_calls() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Simulate multiple functions calling _approve
        table.register(0x10, Jump::Label(0x100)); // approve() -> _approve
        table.register(0x20, Jump::Label(0x100)); // spend_allowance() -> _approve

        assert_target_shift_vs_relocation(table)
    }

    /// Chained label jumps. Checks that the PUSH opcode lands at each jump's
    /// shifted pc.
    #[test]
    fn test_nested_function_calls() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Simulate ERC20's approve -> _approve call chain
        table.register(0x100, Jump::Label(0x200)); // approve entry
        table.register(0x110, Jump::Label(0x300)); // approve -> _approve
        table.register(0x200, Jump::Label(0x400)); // _approve entry

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Check if all jumps use correct PUSH instructions (all targets are > 0xff).
        // Each expected pc is the original pc plus 3 bytes per earlier PUSH2.
        assert_eq!(buffer[0x100], 0x61); // PUSH2, nothing inserted before it
        assert_eq!(buffer[0x113], 0x61); // PUSH2, 0x110 + 3
        assert_eq!(buffer[0x206], 0x61); // PUSH2, 0x200 + 2 * 3

        Ok(())
    }

    /// A label jump and a function call that interleave. The call site lies
    /// between the label jump and its target, so both insertions must be
    /// accounted for.
    #[test]
    fn test_label_call_interaction() -> anyhow::Result<()> {
        init_tracing();
        let mut table = JumpTable::default();

        // NOTE(review): presumably records function 1's entry pc as 0x317 — confirm.
        table.func.insert(1, 0x317);
        // Label jump at 0x10 targeting 0x12 (assumed argument order: pc, target)
        table.label(0x10, 0x12);
        // Call to function 1 from pc 0x11
        table.call(0x11, 1);

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Label operand (after its PUSH1 at 0x10). Target 0x12 follows both
        // insertions, so it moves by 2 (PUSH1) + 3 (PUSH2): 0x12 + 5 = 0x17.
        assert_eq!(buffer[0x11], 0x17, "{buffer:?}");
        // The call (orig. 0x11) is pushed to 0x13 by the 2-byte PUSH1. Its
        // PUSH2 operand is 0x317 + 5 = 0x031c.
        assert_eq!(buffer[0x14], 0x03, "{buffer:?}");
        assert_eq!(buffer[0x15], 0x1c, "{buffer:?}");
        Ok(())
    }

    /// Targets on both sides of 0xff, selecting PUSH1 vs PUSH2.
    #[test]
    fn test_large_target_offset_calculation() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Register a jump with target < 0xff
        table.register(0x10, Jump::Label(0x80));

        // Register a jump with target > 0xff
        table.register(0x20, Jump::Label(0x100));

        // Register a jump with target > 0xfff
        table.register(0x30, Jump::Label(0x1000));

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Check if offsets are correctly calculated
        // For target 0x80: PUSH1 (1 byte) + target (1 byte)
        // For target 0x100: PUSH2 (1 byte) + target (2 bytes)
        // For target 0x1000: PUSH2 (1 byte) + target (2 bytes)
        // All three jump sites precede every target, so each target shifts by
        // 2 + 3 + 3 = 8 bytes. Each jump site shifts by the bytes inserted before it.
        assert_eq!(buffer[0x11], 0x88); // Small target: 0x80 + 8 (operand of PUSH1 at 0x10)
        assert_eq!(buffer[0x23], 0x01); // First byte of 0x0108 (PUSH2 moved 0x20 -> 0x22)
        assert_eq!(buffer[0x24], 0x08); // Second byte of 0x0108
        assert_eq!(buffer[0x36], 0x10); // First byte of 0x1008 (PUSH2 moved 0x30 -> 0x35)
        assert_eq!(buffer[0x37], 0x08); // Second byte of 0x1008

        Ok(())
    }

    /// Twenty consecutive PUSH2 jumps, all placed before their targets.
    #[test]
    fn test_sequential_large_jumps() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Register multiple sequential jumps with increasing targets
        // This mirrors the ERC20 pattern where we have many functions
        for i in 0..20 {
            let target = 0x100 + (i * 0x20);
            table.register(0x10 + i, Jump::Label(target));
        }

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Check first jump (should use PUSH2)
        assert_eq!(buffer[0x10], 0x61); // PUSH2
        assert_eq!(buffer[0x11], 0x01); // First byte
        assert_eq!(buffer[0x12], 0x3c); // Second byte
        // All 20 jump sites precede every target, so each target moves by 20 * 3.
        assert_eq!(0x013c, 0x100 + 20 * 3);

        // Check last jump (should still use PUSH2 but with adjusted offset)
        // Its site moves by the 19 PUSH2s (3 bytes each) inserted before it.
        let last_idx = 0x10 + 19 + 19 * 3;
        assert_eq!(buffer[last_idx], 0x61); // PUSH2
        assert_eq!(buffer[last_idx + 1], 0x03); // First byte of 0x039c
        assert_eq!(buffer[last_idx + 2], 0x9c); // Second byte of 0x039c (shifted by all 20 jumps)
        assert_eq!(0x039c, 0x100 + 0x20 * 19 + 20 * 3);

        Ok(())
    }

    /// Simulates a selector dispatcher. For each selector, a check jump targets
    /// `check_pc + 0x10`, where a second jump goes to the function body.
    #[test]
    fn test_dispatcher_jump_targets() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        let selectors = 5;

        // Register jumps for each selector check
        for i in 0..selectors {
            let i = i as u16;
            let check_pc = 0x10 + i * 0x20;
            let target_pc = 0x100 + i * 0x40;

            // Register both the comparison jump and function jump
            // (the comparison jump targets the function jump's own pc)
            table.register(check_pc, Jump::Label(check_pc + 0x10));
            table.register(check_pc + 0x10, Jump::Label(target_pc));
        }

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Verify each selector's jump chain
        // `total_offset` accumulates the bytes inserted by earlier selectors.
        let mut total_offset = 0;
        for i in 0..selectors {
            let check_pc = 0x10 + i * 0x20 + total_offset;
            // Size of the check jump's PUSH: 3 for PUSH2, 2 for PUSH1
            let check_pc_offset = if check_pc + 0x10 > 0xff { 3 } else { 2 };

            // Relocated pc of the function jump (shifted by the check jump's PUSH)
            let func_pc = check_pc + 0x10 + check_pc_offset;

            let check_jump = buffer[check_pc];
            let func_jump = buffer[func_pc];

            // NOTE(review): the PUSH size above is chosen from `check_pc + 0x10`,
            // but this assertion uses `func_pc`. The two agree only while neither
            // crosses 0xff. With 5 selectors both stay below 0x100, so every check
            // jump is expected to be PUSH1 and the mismatch is never exercised.
            assert_eq!(check_jump, if func_pc > 0xff { 0x61 } else { 0x60 });
            // Function bodies start at >= 0x100, so function jumps need PUSH2
            assert_eq!(func_jump, 0x61);

            // Update total offset for next iteration
            // (the check jump's PUSH plus the 3-byte PUSH2 of the function jump)
            total_offset += check_pc_offset + 3;
        }

        Ok(())
    }
}