//
// Syd: rock-solid application kernel
// src/cgroup.rs: Cgroup v2 management for resource limits
//
// Copyright (c) 2025 Ali Polatel <alip@chesswob.org>
// Based in part upon uutils sandbox-rs crate's lib/resources/cgroup.rs which is:
//   Copyright (c) 2025 Erick Jesus
//   SPDX-License-Identifier: MIT
//
// SPDX-License-Identifier: GPL-3.0

//! Cgroup v2 management for resource limits

use std::{
    borrow::Cow,
    fs::{create_dir_all, remove_dir, write},
    os::fd::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd},
};

use btoi::btoi;
use nix::{errno::Errno, fcntl::OFlag, unistd::Pid};

use crate::{
    compat::{openat2, OpenHow, ResolveFlag},
    err::err2no,
    fd::AT_BADFD,
    fs::readlinkat,
    io::{read_all, write_all},
    path::{XPath, XPathBuf},
};

/// Default root directory under which cgroups are created.
///
/// Used by `cgroup_root_path` when the `SYD_CGROUP_ROOT` environment
/// variable is not set.
const CGROUP_V2_ROOT: &[u8] = b"/sys/fs/cgroup";

/// Cgroup v2 resource limits configuration
///
/// Every field is optional: `Cgroup::apply_config` only writes the
/// controller files that correspond to fields holding `Some`.
#[derive(Debug, Copy, Clone, Default)]
pub struct CgroupConfig {
    /// Memory limit in bytes (e.g., 100MB), written to `memory.max`.
    ///
    /// A limit of zero is rejected by `validate`.
    pub memory_limit: Option<u64>,
    /// CPU weight (100-10000, default 100), written to `cpu.weight`.
    ///
    /// Values outside 100-10000 are rejected by `validate`.
    pub cpu_weight: Option<u32>,
    /// CPU quota in microseconds, written to `cpu.max`.
    ///
    /// `u64::MAX` stands for "no quota" and is written as `max`.
    pub cpu_quota: Option<u64>,
    /// CPU period in microseconds (default 100000)
    ///
    /// Only consulted when `cpu_quota` is set.
    pub cpu_period: Option<u64>,
    /// Max PIDs allowed, written to `pids.max`.
    pub max_pids: Option<u32>,
}

impl CgroupConfig {
    /// Build a configuration whose only limit is a memory cap of `limit` bytes.
    pub fn with_memory(limit: u64) -> Self {
        Self::default().memory(limit)
    }

    /// Build a configuration whose only limit is a CPU quota.
    ///
    /// Both `quota` and `period` are in microseconds.
    pub fn with_cpu_quota(quota: u64, period: u64) -> Self {
        Self::default().cpu_quota(quota, period)
    }

    /// Return this configuration with the memory limit set to `limit` bytes.
    pub fn memory(self, limit: u64) -> Self {
        Self {
            memory_limit: Some(limit),
            ..self
        }
    }

    /// Return this configuration with the CPU quota and period set.
    ///
    /// Both `quota` and `period` are in microseconds.
    pub fn cpu_quota(self, quota: u64, period: u64) -> Self {
        Self {
            cpu_quota: Some(quota),
            cpu_period: Some(period),
            ..self
        }
    }

    /// Return this configuration with the CPU quota expressed as a
    /// percentage (0-100) of a fixed 100000 microsecond period.
    ///
    /// # Panics
    ///
    /// Panics if `percent` is greater than 100.
    #[expect(clippy::arithmetic_side_effects)]
    pub fn cpu_limit_percent(self, percent: u64) -> Self {
        assert!(percent <= 100, "BUG: Invalid CPU limit percentage");
        let period = 100000;
        // One percent of the period is 1000 microseconds;
        // `percent` is at most 100 here so this cannot overflow.
        let quota = percent * 1000;
        self.cpu_quota(quota, period)
    }

    /// Return this configuration with the PID limit set to `max`.
    pub fn max_pids(self, max: u32) -> Self {
        Self {
            max_pids: Some(max),
            ..self
        }
    }

    /// Check the configuration for values that are never acceptable.
    ///
    /// # Errors
    ///
    /// Returns `EINVAL` when the memory limit is zero or when the CPU
    /// weight lies outside 100-10000. Unset fields are always valid.
    ///
    /// NOTE(review): the kernel documents `cpu.weight` as accepting
    /// 1-10000; this check is stricter -- confirm that is intended.
    pub fn validate(&self) -> Result<(), Errno> {
        let zero_memory = self.memory_limit == Some(0);
        let weight_out_of_range = self
            .cpu_weight
            .is_some_and(|weight| !(100..=10000).contains(&weight));

        if zero_memory || weight_out_of_range {
            Err(Errno::EINVAL)
        } else {
            Ok(())
        }
    }
}

/// Cgroup v2 interface
///
/// Wraps the file descriptor of a cgroup directory (opened with
/// `O_PATH` by `Cgroup::new`). Controller files are opened relative
/// to this descriptor. Dropping the value attempts, on a best-effort
/// basis, to remove the cgroup directory.
pub struct Cgroup(pub OwnedFd);

impl AsFd for Cgroup {
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.0.as_fd()
    }
}

impl AsRawFd for Cgroup {
    fn as_raw_fd(&self) -> RawFd {
        self.0.as_raw_fd()
    }
}

impl IntoRawFd for Cgroup {
    /// Give up ownership of the cgroup directory file descriptor.
    ///
    /// The descriptor is not closed and the `Drop` implementation of
    /// `Cgroup` does not run, so the cgroup directory is not removed.
    fn into_raw_fd(self) -> RawFd {
        // Suppress the destructor, then hand out the raw descriptor.
        let this = std::mem::ManuallyDrop::new(self);
        this.0.as_raw_fd()
    }
}

/// Return the directory under which cgroups are created.
///
/// The value of the `SYD_CGROUP_ROOT` environment variable is used
/// when it is set; otherwise the path falls back to `CGROUP_V2_ROOT`.
fn cgroup_root_path() -> XPathBuf {
    match std::env::var_os("SYD_CGROUP_ROOT") {
        Some(root) => XPathBuf::from(root),
        None => XPathBuf::from(CGROUP_V2_ROOT),
    }
}

impl Cgroup {
    /// Create the cgroup directory `name` below the cgroup root and open it.
    ///
    /// The directory (and any missing parents) is created first, then
    /// `ensure_controller_files` fills in missing controller files, and
    /// finally the directory is opened as an `O_PATH` descriptor.
    ///
    /// # Errors
    ///
    /// Returns the errno of the first step that fails.
    pub fn new(name: &[u8]) -> Result<Self, Errno> {
        // TODO: open cgroup_root first and verify inode!
        let dir = cgroup_root_path().join(name);

        // Make sure the directory and its controller files are in place.
        create_dir_all(&dir).map_err(|err| err2no(&err))?;
        ensure_controller_files(&dir)?;

        // Open the directory itself; symlinks and magic links are refused.
        let flags = OFlag::O_PATH | OFlag::O_DIRECTORY | OFlag::O_NOFOLLOW | OFlag::O_CLOEXEC;
        let resolve = ResolveFlag::RESOLVE_NO_MAGICLINKS | ResolveFlag::RESOLVE_NO_SYMLINKS;
        let how = OpenHow::new().flags(flags).resolve(resolve);

        #[expect(clippy::disallowed_methods)]
        let dir_fd = openat2(AT_BADFD, &dir, how)?;

        Ok(Self(dir_fd))
    }

    /// Apply configuration to cgroup
    pub fn apply_config(&self, config: &CgroupConfig) -> Result<(), Errno> {
        config.validate()?;

        if let Some(memory) = config.memory_limit {
            self.set_memory_limit(memory)?;
        }

        if let Some(weight) = config.cpu_weight {
            self.set_cpu_weight(weight)?;
        }

        if let Some(quota) = config.cpu_quota {
            let period = config.cpu_period.unwrap_or(100000);
            self.set_cpu_quota(quota, period)?;
        }

        if let Some(max_pids) = config.max_pids {
            self.set_max_pids(max_pids)?;
        }

        Ok(())
    }

    /// Move the process `pid` into this cgroup.
    ///
    /// Writes the decimal process id to `cgroup.procs`.
    pub fn add_process(&self, pid: Pid) -> Result<(), Errno> {
        let raw_pid = pid.as_raw();
        let mut digits = itoa::Buffer::new();
        let text = digits.format(raw_pid);
        self.write_file(b"cgroup.procs", text.as_bytes())
    }

    /// Write `limit` (bytes, in decimal) to `memory.max`.
    fn set_memory_limit(&self, limit: u64) -> Result<(), Errno> {
        let mut digits = itoa::Buffer::new();
        let text = digits.format(limit);
        self.write_file(b"memory.max", text.as_bytes())
    }

    /// Write `weight` (in decimal) to `cpu.weight`.
    fn set_cpu_weight(&self, weight: u32) -> Result<(), Errno> {
        let mut digits = itoa::Buffer::new();
        let text = digits.format(weight);
        self.write_file(b"cpu.weight", text.as_bytes())
    }

    /// Set CPU quota and period (both in microseconds).
    ///
    /// Writes `"<quota> <period>"` to `cpu.max`. A `quota` of
    /// `u64::MAX` means "no limit" and is written as the literal `max`.
    ///
    /// The period is always written, including in the unlimited case:
    /// when only a single token is written to `cpu.max` the kernel
    /// updates the quota and keeps the previous period, which would
    /// silently discard the `period` requested by the caller.
    ///
    /// # Errors
    ///
    /// Returns the errno of opening or writing `cpu.max`.
    fn set_cpu_quota(&self, quota: u64, period: u64) -> Result<(), Errno> {
        let mut buf = itoa::Buffer::new();

        // First token: the literal "max" for unlimited, decimal otherwise.
        let quota_token: Cow<'_, [u8]> = if quota == u64::MAX {
            Cow::Borrowed(&b"max"[..])
        } else {
            Cow::Owned(buf.format(quota).as_bytes().to_vec())
        };

        // Second token: the period, separated by a single space.
        let mut line = quota_token.into_owned();
        line.push(b' ');
        line.extend(buf.format(period).as_bytes());

        self.write_file(b"cpu.max", &line)
    }

    /// Write `max_pids` (in decimal) to `pids.max`.
    fn set_max_pids(&self, max_pids: u32) -> Result<(), Errno> {
        let mut digits = itoa::Buffer::new();
        let text = digits.format(max_pids);
        self.write_file(b"pids.max", text.as_bytes())
    }

    /// Read the value of `memory.current`.
    ///
    /// # Errors
    ///
    /// Returns the errno of opening or reading the file, or `EINVAL`
    /// when its content is not an unsigned decimal integer.
    pub fn get_memory_usage(&self) -> Result<u64, Errno> {
        const FILE: &[u8] = b"memory.current";
        self.read_file_u64(FILE)
    }

    /// Read memory limit
    pub fn get_memory_limit(&self) -> Result<u64, Errno> {
        self.read_file_u64(b"memory.max")
    }

    /// Read the `usage_usec` value from `cpu.stat`.
    ///
    /// Returns `0` when the file has no `usage_usec` line that carries
    /// a value.
    ///
    /// # Errors
    ///
    /// Returns the errno of opening or reading the file, or `EINVAL`
    /// when the value is not an unsigned decimal integer.
    pub fn get_cpu_usage(&self) -> Result<u64, Errno> {
        let stat = self.read_file(b"cpu.stat")?;

        // Each line has the form "key value", e.g. "usage_usec 123456".
        for line in stat.split(|&byte| byte == b'\n') {
            // Tokenize on ASCII whitespace, dropping empty fields.
            let mut fields = line
                .split(|byte| byte.is_ascii_whitespace())
                .filter(|field| !field.is_empty());

            // Lines without both a key and a value are skipped.
            let (Some(key), Some(value)) = (fields.next(), fields.next()) else {
                continue;
            };

            if key == b"usage_usec" {
                return btoi::<u64>(value).or(Err(Errno::EINVAL));
            }
        }

        Ok(0)
    }

    /// Report whether the cgroup directory still exists.
    ///
    /// Returns `false` when the path behind the descriptor cannot be
    /// read or when that path ends with ` (deleted)`.
    pub fn exists(&self) -> bool {
        self.path()
            .is_ok_and(|link| !link.ends_with(b" (deleted)"))
    }

    /// Get the path of this cgroup
    pub fn path(&self) -> Result<XPathBuf, Errno> {
        let mut path = XPathBuf::from("/proc/thread-self/fd");
        path.push_fd(self.0.as_raw_fd());
        readlinkat(AT_BADFD, &path)
    }

    /// Remove the cgroup directory.
    ///
    /// No existence check is made up front: the removal is simply
    /// attempted and its error, if any, is returned.
    ///
    /// # Errors
    ///
    /// Returns the errno of resolving the path or of removing the
    /// directory.
    pub fn delete(&self) -> Result<(), Errno> {
        let dir = self.path()?;
        remove_dir(&dir).map_err(|err| err2no(&err))
    }

    /// Open the cgroup file `name` write-only and write `content` to it.
    fn write_file(&self, name: &[u8], content: &[u8]) -> Result<(), Errno> {
        let file = self.open_file(name, OFlag::O_WRONLY)?;
        write_all(file, content)
    }

    /// Read the cgroup file `name` and parse it as an unsigned decimal integer.
    ///
    /// Surrounding ASCII whitespace is ignored; anything that does not
    /// parse yields `EINVAL`.
    fn read_file_u64(&self, name: &[u8]) -> Result<u64, Errno> {
        let raw = self.read_file(name)?;
        let digits = raw.trim_ascii();
        btoi::<u64>(digits).or(Err(Errno::EINVAL))
    }

    fn read_file(&self, name: &[u8]) -> Result<Vec<u8>, Errno> {
        self.open_file(name, OFlag::O_RDONLY).and_then(read_all)
    }

    /// Open the file `name` inside this cgroup directory with `flags`.
    ///
    /// `name` must be a single path component: a name that contains
    /// `/`, or that `is_dot()` / `has_parent_dot()` flag, is rejected
    /// with `EINVAL` before any open is attempted.
    fn open_file(&self, name: &[u8], flags: OFlag) -> Result<OwnedFd, Errno> {
        // Reject anything that is not a plain filename.
        let name = XPath::from_bytes(name);
        let is_unsafe = name.contains_char(b'/') || name.is_dot() || name.has_parent_dot();
        if is_unsafe {
            return Err(Errno::EINVAL);
        }

        // Resolve beneath the cgroup directory only, never through
        // symlinks or magic links.
        let resolve = ResolveFlag::RESOLVE_BENEATH
            | ResolveFlag::RESOLVE_NO_MAGICLINKS
            | ResolveFlag::RESOLVE_NO_SYMLINKS;
        let how = OpenHow::new()
            .flags(flags | OFlag::O_NOFOLLOW | OFlag::O_CLOEXEC)
            .resolve(resolve);

        #[expect(clippy::disallowed_methods)]
        openat2(&self.0, name, how)
    }

    /// Test-only constructor: wrap an already existing directory at
    /// `path` without creating anything.
    ///
    /// Panics if the directory cannot be opened.
    #[cfg(test)]
    pub(crate) fn for_testing(path: XPathBuf) -> Self {
        use nix::{fcntl::open, sys::stat::Mode};

        let flags = OFlag::O_PATH | OFlag::O_DIRECTORY | OFlag::O_CLOEXEC;
        let dir_fd = open(&path, flags, Mode::empty()).unwrap();

        Self(dir_fd)
    }
}

/// Create every missing controller file under `path` with a default value.
///
/// Files reported as existing are left untouched. The first failing
/// write stops the iteration and its errno is returned.
fn ensure_controller_files(path: &XPath) -> Result<(), Errno> {
    // (file name, content written when the file is missing)
    const DEFAULTS: &[(&[u8], &str)] = &[
        (b"memory.max", "max"),
        (b"memory.current", "0"),
        (b"cpu.weight", "100"),
        (b"cpu.max", "max 100000"),
        (b"cpu.stat", "usage_usec 0\n"),
        (b"pids.max", "max"),
        (b"cgroup.procs", ""),
    ];

    DEFAULTS.iter().try_for_each(|(name, contents)| {
        let target = path.join(name);
        if target.exists(false) {
            Ok(())
        } else {
            write(&target, contents).map_err(|err| err2no(&err))
        }
    })
}

impl Drop for Cgroup {
    fn drop(&mut self) {
        // Clean up cgroup on drop (best effort)
        let _ = self.delete();
    }
}

// Unit tests. Most of them run against a fake cgroup directory made of
// regular files in a temporary directory, so they need neither root
// nor a mounted cgroup v2 hierarchy.
#[cfg(test)]
mod tests {
    use std::{
        env,
        fs::{read_dir, read_to_string, remove_file},
    };

    use tempfile::tempdir;

    use super::*;

    // Build a temporary directory that looks like a cgroup: a
    // `cgroup-test` directory holding one regular file per controller
    // file used by `Cgroup`. The `TempDir` guard is returned alongside
    // the path so the caller keeps the directory alive.
    fn prepare_cgroup_dir() -> (tempfile::TempDir, XPathBuf) {
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("cgroup-test");
        create_dir_all(&path).unwrap();
        for file in &[
            "memory.max",
            "memory.current",
            "cpu.weight",
            "cpu.max",
            "cpu.stat",
            "pids.max",
            "cgroup.procs",
        ] {
            let file_path = path.join(file);
            if let Some(parent) = file_path.parent() {
                create_dir_all(parent).unwrap();
            }
            // Placeholder content; the two files below are overwritten
            // with content in the shape the readers expect.
            write(&file_path, "0").unwrap();
        }
        write(path.join("cpu.stat"), "usage_usec 0\n").unwrap();
        write(path.join("memory.current"), "0\n").unwrap();
        (tmp, path.into())
    }

    // A default configuration sets no limits.
    #[test]
    fn test_cgroup_config_default() {
        let config = CgroupConfig::default();
        assert!(config.memory_limit.is_none());
        assert!(config.cpu_weight.is_none());
    }

    // `with_memory` stores the requested limit.
    #[test]
    fn test_cgroup_config_with_memory() {
        let config = CgroupConfig::with_memory(100 * 1024 * 1024);
        assert_eq!(config.memory_limit, Some(100 * 1024 * 1024));
    }

    // `with_cpu_quota` stores both quota and period.
    #[test]
    fn test_cgroup_config_with_cpu_quota() {
        let config = CgroupConfig::with_cpu_quota(50000, 100000);
        assert_eq!(config.cpu_quota, Some(50000));
        assert_eq!(config.cpu_period, Some(100000));
    }

    // `validate` accepts the default, rejects a zero memory limit and
    // a CPU weight below 100, and accepts the lower bound of 100.
    #[test]
    fn test_cgroup_config_validate() {
        let config = CgroupConfig::default();
        assert!(config.validate().is_ok());

        let bad_config = CgroupConfig {
            memory_limit: Some(0),
            ..Default::default()
        };
        assert!(bad_config.validate().is_err());

        let bad_cpu_config = CgroupConfig {
            cpu_weight: Some(50),
            ..Default::default()
        };
        assert!(bad_cpu_config.validate().is_err());

        let good_cpu_config = CgroupConfig {
            cpu_weight: Some(100),
            ..Default::default()
        };
        assert!(good_cpu_config.validate().is_ok());
    }

    // Smoke test against the real cgroup root: the call must not panic,
    // whatever its result.
    #[test]
    fn test_cgroup_path_creation() {
        // This test may only work if running as root and cgroup v2 is available
        // We'll test the logic without actually creating cgroups
        let test_path = XPath::from_bytes(CGROUP_V2_ROOT);
        if test_path.exists(true) {
            // Cgroup v2 is available
            let result = Cgroup::new(b"sandbox-test-delete-me");
            // Don't assert, as it may fail due to permissions
            let _ = result;
        }
    }

    // `apply_config` writes each configured limit to its controller file.
    #[test]
    fn test_cgroup_apply_config_writes_files() {
        let (_tmp, path) = prepare_cgroup_dir();
        let cgroup = Cgroup::for_testing(path.clone());

        let config = CgroupConfig {
            memory_limit: Some(2048),
            cpu_weight: Some(500),
            cpu_quota: Some(50_000),
            cpu_period: Some(100_000),
            max_pids: Some(32),
        };

        cgroup.apply_config(&config).unwrap();

        assert_eq!(
            read_to_string(path.join(b"memory.max")).unwrap().trim(),
            "2048"
        );
        assert_eq!(
            read_to_string(path.join(b"cpu.weight")).unwrap().trim(),
            "500"
        );
        assert_eq!(
            read_to_string(path.join(b"cpu.max")).unwrap().trim(),
            "50000 100000"
        );
        assert_eq!(read_to_string(path.join(b"pids.max")).unwrap().trim(), "32");
    }

    // `add_process` writes the decimal pid to `cgroup.procs`.
    #[test]
    fn test_cgroup_add_process_writes_pid() {
        let (_tmp, path) = prepare_cgroup_dir();
        let cgroup = Cgroup::for_testing(path.clone());

        cgroup.add_process(Pid::from_raw(1234)).unwrap();
        assert_eq!(read_to_string(path.join(b"cgroup.procs")).unwrap(), "1234");
    }

    // The readers parse the values stored in the controller files.
    #[test]
    fn test_cgroup_resource_readers() {
        let (_tmp, path) = prepare_cgroup_dir();
        write(path.join(b"memory.current"), "4096").unwrap();
        write(path.join(b"cpu.stat"), "usage_usec 900\n").unwrap();
        let cgroup = Cgroup::for_testing(path.clone());

        assert_eq!(cgroup.get_memory_usage().unwrap(), 4096);
        assert_eq!(cgroup.get_cpu_usage().unwrap(), 900);
    }

    // `delete` removes the directory once it is empty.
    #[test]
    fn test_cgroup_delete_removes_directory() {
        let (tmp, path) = prepare_cgroup_dir();
        let cgroup = Cgroup::for_testing(path.clone());
        assert!(path.exists(true));
        // The fake cgroup holds regular files, which have to be removed
        // first because `delete` uses `remove_dir`.
        for entry in read_dir(&path).unwrap() {
            let entry = entry.unwrap();
            if entry.path().is_file() {
                remove_file(entry.path()).unwrap();
            }
        }
        cgroup.delete().unwrap();
        assert!(!path.exists(true));
        drop(tmp);
    }

    // `Cgroup::new` creates the cgroup below `SYD_CGROUP_ROOT` when
    // that variable is set.
    #[test]
    fn test_cgroup_new_uses_env_override() {
        let tmp = tempdir().unwrap();
        // Remember the previous value so it can be restored afterwards.
        // Note that the restore is skipped if an assertion below panics.
        let prev = env::var("SYD_CGROUP_ROOT").ok();
        // SAFETY: NOTE(review): mutating the environment is only sound
        // while no other thread reads or writes it concurrently, and
        // the test harness may run tests on several threads -- confirm
        // nothing else touches the environment at the same time.
        unsafe {
            env::set_var("SYD_CGROUP_ROOT", tmp.path());
        }

        let cg = Cgroup::new(b"env-test").unwrap();
        assert!(cg.exists());
        assert!(tmp.path().join("env-test").exists());

        if let Some(value) = prev {
            // SAFETY: see the note on the first `set_var` above.
            unsafe {
                env::set_var("SYD_CGROUP_ROOT", value);
            }
        } else {
            // SAFETY: see the note on the first `set_var` above.
            unsafe {
                env::remove_var("SYD_CGROUP_ROOT");
            }
        }
    }

    // A configuration carrying several limits at once validates and
    // keeps every value.
    #[test]
    fn cgroup_config_combines_multiple_limits() {
        let mut config = CgroupConfig::with_memory(256 * 1024 * 1024);
        config.cpu_weight = Some(500);
        config.cpu_quota = Some(50_000);
        config.cpu_period = Some(100_000);
        config.max_pids = Some(32);

        assert!(config.validate().is_ok());
        assert_eq!(config.memory_limit, Some(256 * 1024 * 1024));
        assert_eq!(config.cpu_weight, Some(500));
        assert_eq!(config.max_pids, Some(32));
    }

    // `validate` rejects a zero memory limit and CPU weights on either
    // side of the accepted 100-10000 range.
    #[test]
    fn cgroup_config_rejects_invalid_values() {
        let bad_memory = CgroupConfig {
            memory_limit: Some(0),
            ..Default::default()
        };
        assert!(bad_memory.validate().is_err());

        let bad_weight_low = CgroupConfig {
            cpu_weight: Some(10),
            ..Default::default()
        };
        assert!(bad_weight_low.validate().is_err());

        let bad_weight_high = CgroupConfig {
            cpu_weight: Some(20_000),
            ..Default::default()
        };
        assert!(bad_weight_high.validate().is_err());
    }

    // The `with_*` constructors set exactly the fields they are named after.
    #[test]
    fn cgroup_config_helpers_set_expected_fields() {
        let memory = CgroupConfig::with_memory(64 * 1024 * 1024);
        assert_eq!(memory.memory_limit, Some(64 * 1024 * 1024));

        let quota = CgroupConfig::with_cpu_quota(100_000, 200_000);
        assert_eq!(quota.cpu_quota, Some(100_000));
        assert_eq!(quota.cpu_period, Some(200_000));
    }
}
