/*
 * Copyright (C) 2018-2024 Intel Corporation
 *
 * SPDX-License-Identifier: MIT
 *
 */

#include "shared/source/command_stream/aub_command_stream_receiver.h"

#include "shared/source/aub/aub_helper.h"
#include "shared/source/debug_settings/debug_settings_manager.h"
#include "shared/source/execution_environment/execution_environment.h"
#include "shared/source/execution_environment/root_device_environment.h"
#include "shared/source/helpers/basic_math.h"
#include "shared/source/helpers/debug_helpers.h"
#include "shared/source/helpers/gfx_core_helper.h"
#include "shared/source/helpers/hw_info.h"
#include "shared/source/helpers/options.h"
#include "shared/source/os_interface/os_inc_base.h"
#include "shared/source/os_interface/sys_calls_common.h"
#include "shared/source/release_helper/release_helper.h"

#include <algorithm>
#include <cstring>
#include <sstream>

namespace NEO {
// Table of AUB command stream receiver creation functions with IGFX_MAX_CORE entries.
// AUBCommandStreamReceiver::create indexes it with the render core family; every entry
// starts out value-initialized (null) and create returns nullptr for a null entry.
AubCommandStreamReceiverCreateFunc aubCommandStreamReceiverFactory[IGFX_MAX_CORE] = {};

// Builds the full path of the AUB capture file for a device.
//
// The file name is composed as
//   <hwPrefix>_<deviceConfig>_<rootDeviceIndex>_<filename>[_PID_<pid>].aub
// where the "_PID_<pid>" part is appended only when the GenerateAubFilePerProcessId
// debug flag is set. Every '/' in the composed name is replaced with '_'.
// The name is appended (after Os::fileSeparator) to folderAUB, or to the value of the
// AUBDumpCaptureDirPath debug flag when that flag is not "unk".
//
// hwInfo          - supplies the product family, GT system info and IP version
// filename        - base name provided by the caller
// rootDeviceIndex - root device index embedded in the file name
// Returns the directory, separator and sanitized file name concatenated.
std::string AUBCommandStreamReceiver::createFullFilePath(const HardwareInfo &hwInfo, const std::string &filename, uint32_t rootDeviceIndex) {
    const std::string hwPrefix = hardwarePrefix[hwInfo.platform.eProductFamily];

    const auto &gtSystemInfo = hwInfo.gtSystemInfo;
    const auto subDevicesCount = GfxCoreHelper::getSubDevicesCount(&hwInfo);
    // A hardware description that reports zero slices must not trigger an integer
    // division by zero (undefined behaviour); report zero sub-slices per slice instead.
    const uint32_t subSlicesPerSlice = gtSystemInfo.SliceCount != 0 ? gtSystemInfo.SubSliceCount / gtSystemInfo.SliceCount : 0u;

    // Caller-supplied base name, optionally extended with the process id.
    std::stringstream strExtendedFileName;
    strExtendedFileName << filename;
    if (debugManager.flags.GenerateAubFilePerProcessId.get()) {
        strExtendedFileName << "_PID_" << SysCalls::getProcessId();
    }

    auto releaseHelper = ReleaseHelper::create(hwInfo.ipVersion);
    const auto deviceConfig = AubHelper::getDeviceConfigString(releaseHelper.get(), subDevicesCount, gtSystemInfo.SliceCount, subSlicesPerSlice, gtSystemInfo.MaxEuPerSubSlice);

    // Generate the full filename
    std::stringstream strfilename;
    strfilename << hwPrefix << "_" << deviceConfig << "_" << rootDeviceIndex << "_" << strExtendedFileName.str() << ".aub";

    // clean-up any fileName issues because of the file system incompatibilities
    auto fileName = strfilename.str();
    std::replace(fileName.begin(), fileName.end(), '/', '_');

    // The capture directory can be overridden with a debug flag; "unk" means "not set".
    std::string filePath(folderAUB);
    if (debugManager.flags.AUBDumpCaptureDirPath.get() != "unk") {
        filePath.assign(debugManager.flags.AUBDumpCaptureDirPath.get());
    }

    filePath.append(Os::fileSeparator);
    filePath.append(fileName);

    return filePath;
}

// Creates an AUB command stream receiver for the given root device.
//
// The capture path is generated by createFullFilePath from baseName and is replaced by
// the AUBDumpCaptureFileName debug flag when that flag is not "unk". The receiver is
// produced by the factory entry registered for the device's render core family.
// Returns nullptr when the core family is outside the factory table or has no
// creation function registered.
CommandStreamReceiver *AUBCommandStreamReceiver::create(const std::string &baseName,
                                                        bool standalone,
                                                        ExecutionEnvironment &executionEnvironment,
                                                        uint32_t rootDeviceIndex,
                                                        const DeviceBitfield deviceBitfield) {
    const auto pHwInfo = executionEnvironment.rootDeviceEnvironments[rootDeviceIndex]->getHardwareInfo();

    std::string capturePath = AUBCommandStreamReceiver::createFullFilePath(*pHwInfo, baseName, rootDeviceIndex);
    if (debugManager.flags.AUBDumpCaptureFileName.get() != "unk") {
        // An explicitly requested capture file name wins over the generated one.
        capturePath = debugManager.flags.AUBDumpCaptureFileName.get();
    }

    const auto coreFamily = pHwInfo->platform.eRenderCoreFamily;
    if (!(coreFamily < IGFX_MAX_CORE)) {
        // Core family does not fit in the factory table.
        DEBUG_BREAK_IF(!false);
        return nullptr;
    }

    const auto createFunc = aubCommandStreamReceiverFactory[coreFamily];
    if (!createFunc) {
        return nullptr;
    }
    return createFunc(capturePath, standalone, executionEnvironment, rootDeviceIndex, deviceBitfield);
}
} // namespace NEO

namespace AubMemDump {
// Short names for the AUB command structures used by the AubFileStream methods below.
using CmdServicesMemTraceMemoryCompare = AubMemDump::CmdServicesMemTraceMemoryCompare;
using CmdServicesMemTraceMemoryWrite = AubMemDump::CmdServicesMemTraceMemoryWrite;
using CmdServicesMemTraceRegisterPoll = AubMemDump::CmdServicesMemTraceRegisterPoll;
using CmdServicesMemTraceRegisterWrite = AubMemDump::CmdServicesMemTraceRegisterWrite;
using CmdServicesMemTraceVersion = AubMemDump::CmdServicesMemTraceVersion;

// Size in bytes of a memory-write command without its trailing `data` member,
// i.e. the part that writeMemoryWriteHeader emits ahead of the payload.
static auto sizeMemoryWriteHeader = sizeof(CmdServicesMemTraceMemoryWrite) - sizeof(CmdServicesMemTraceMemoryWrite::data);

// Defined elsewhere; used below as the upper bound on a command's dword count.
extern const size_t dwordCountMax;

// Opens filePath for binary output and records it as the name of the current file.
void AubFileStream::open(const char *filePath) {
    fileHandle.open(filePath, std::ios::binary);
    fileName = filePath;
}

// Closes the underlying file handle and forgets the stored file name.
void AubFileStream::close() {
    fileHandle.close();
    fileName.clear();
}

// Appends `size` raw bytes starting at `data` to the file handle.
void AubFileStream::write(const char *data, size_t size) {
    const auto byteCount = static_cast<std::streamsize>(size);
    fileHandle.write(data, byteCount);
}

// Flushes the underlying file handle.
void AubFileStream::flush() {
    fileHandle.flush();
}

bool AubFileStream::init(uint32_t stepping, uint32_t device) {
    CmdServicesMemTraceVersion header = {};

    header.setHeader();
    header.dwordCount = (sizeof(header) / sizeof(uint32_t)) - 1;
    header.stepping = stepping;
    header.metal = 0;
    header.device = device;
    header.csxSwizzling = CmdServicesMemTraceVersion::CsxSwizzlingValues::Disabled;
    // Which recording method used:
    //  Phys is required for GGTT memory to be written directly to phys vs through aperture.
    header.recordingMethod = CmdServicesMemTraceVersion::RecordingMethodValues::Phy;
    header.pch = CmdServicesMemTraceVersion::PchValues::Default;
    header.captureTool = CmdServicesMemTraceVersion::CaptureToolValues::GenKmdCapture;
    header.primaryVersion = 0;
    header.secondaryVersion = 0;
    header.commandLine[0] = 'N';
    header.commandLine[1] = 'E';
    header.commandLine[2] = 'O';
    header.commandLine[3] = 0;

    write(reinterpret_cast<char *>(&header), sizeof(header));
    return true;
}

// Emits a memory-write command: the header, then `size` bytes read from `memory`,
// then zero bytes so that the payload ends on a 4-byte boundary.
void AubFileStream::writeMemory(uint64_t physAddress, const void *memory, size_t size, uint32_t addressSpace, uint32_t hint) {
    writeMemoryWriteHeader(physAddress, size, addressSpace, hint);

    // Payload goes out unchanged right after the header.
    write(static_cast<const char *>(memory), size);

    // Number of payload bytes past the last full dword.
    const size_t tailBytes = size & (sizeof(uint32_t) - 1);
    if (tailBytes != 0) {
        // Fill the rest of the last dword with zeros.
        uint32_t padding = 0;
        write(reinterpret_cast<char *>(&padding), sizeof(padding) - tailBytes);
    }
}

void AubFileStream::writeMemoryWriteHeader(uint64_t physAddress, size_t size, uint32_t addressSpace, uint32_t hint) {
    CmdServicesMemTraceMemoryWrite header = {};
    auto alignedBlockSize = (size + sizeof(uint32_t) - 1) & ~(sizeof(uint32_t) - 1);
    auto dwordCount = (sizeMemoryWriteHeader + alignedBlockSize) / sizeof(uint32_t);
    DEBUG_BREAK_IF(dwordCount > AubMemDump::dwordCountMax);

    header.setHeader();
    header.dwordCount = static_cast<uint32_t>(dwordCount - 1);
    header.address = physAddress;
    header.repeatMemory = CmdServicesMemTraceMemoryWrite::RepeatMemoryValues::NoRepeat;
    header.tiling = CmdServicesMemTraceMemoryWrite::TilingValues::NoTiling;
    header.dataTypeHint = hint;
    header.addressSpace = addressSpace;
    header.dataSizeInBytes = static_cast<uint32_t>(size);

    write(reinterpret_cast<const char *>(&header), sizeMemoryWriteHeader);
}

// Writes the 8 raw bytes of `entry` to the stream; `gttOffset` is not used here.
void AubFileStream::writeGTT(uint32_t gttOffset, uint64_t entry) {
    const auto *entryBytes = reinterpret_cast<const char *>(&entry);
    write(entryBytes, sizeof(uint64_t));
}

// Writes the 8 raw bytes of `entry` to the stream; `physAddress` and `addressSpace`
// are not used here.
void AubFileStream::writePTE(uint64_t physAddress, uint64_t entry, uint32_t addressSpace) {
    const auto *entryBytes = reinterpret_cast<const char *>(&entry);
    write(entryBytes, sizeof(uint64_t));
}

void AubFileStream::writeMMIOImpl(uint32_t offset, uint32_t value) {
    CmdServicesMemTraceRegisterWrite header = {};
    header.setHeader();
    header.dwordCount = (sizeof(header) / sizeof(uint32_t)) - 1;
    header.registerOffset = offset;
    header.messageSourceId = MessageSourceIdValues::Ia;
    header.registerSize = RegisterSizeValues::Dword;
    header.registerSpace = RegisterSpaceValues::Mmio;
    header.writeMaskLow = 0xffffffff;
    header.writeMaskHigh = 0x00000000;
    header.data[0] = value;

    write(reinterpret_cast<char *>(&header), sizeof(header));
}

void AubFileStream::registerPoll(uint32_t registerOffset, uint32_t mask, uint32_t value, bool pollNotEqual, uint32_t timeoutAction) {
    CmdServicesMemTraceRegisterPoll header = {};
    header.setHeader();
    header.registerOffset = registerOffset;
    header.timeoutAction = timeoutAction;
    header.pollNotEqual = pollNotEqual;
    header.operationType = CmdServicesMemTraceRegisterPoll::OperationTypeValues::Normal;
    header.registerSize = CmdServicesMemTraceRegisterPoll::RegisterSizeValues::Dword;
    header.registerSpace = CmdServicesMemTraceRegisterPoll::RegisterSpaceValues::Mmio;
    header.pollMaskLow = mask;
    header.data[0] = value;
    header.dwordCount = (sizeof(header) / sizeof(uint32_t)) - 1;

    write(reinterpret_cast<char *>(&header), sizeof(header));
}

void AubFileStream::expectMMIO(uint32_t mmioRegister, uint32_t expectedValue) {
    using AubMemDump::CmdServicesMemTraceRegisterCompare;
    CmdServicesMemTraceRegisterCompare header;
    memset(&header, 0, sizeof(header));
    header.setHeader();

    header.data[0] = expectedValue;
    header.registerOffset = mmioRegister;
    header.noReadExpect = CmdServicesMemTraceRegisterCompare::NoReadExpectValues::ReadExpect;
    header.registerSize = CmdServicesMemTraceRegisterCompare::RegisterSizeValues::Dword;
    header.registerSpace = CmdServicesMemTraceRegisterCompare::RegisterSpaceValues::Mmio;
    header.readMaskLow = 0xffffffff;
    header.readMaskHigh = 0xffffffff;
    header.dwordCount = (sizeof(header) / sizeof(uint32_t)) - 1;

    write(reinterpret_cast<char *>(&header), sizeof(header));
}

void AubFileStream::expectMemory(uint64_t physAddress, const void *memory, size_t sizeRemaining,
                                 uint32_t addressSpace, uint32_t compareOperation) {
    using CmdServicesMemTraceMemoryCompare = AubMemDump::CmdServicesMemTraceMemoryCompare;
    CmdServicesMemTraceMemoryCompare header = {};
    header.setHeader();

    header.noReadExpect = CmdServicesMemTraceMemoryCompare::NoReadExpectValues::ReadExpect;
    header.repeatMemory = CmdServicesMemTraceMemoryCompare::RepeatMemoryValues::NoRepeat;
    header.tiling = CmdServicesMemTraceMemoryCompare::TilingValues::NoTiling;
    header.crcCompare = CmdServicesMemTraceMemoryCompare::CrcCompareValues::NoCrc;
    header.compareOperation = compareOperation;
    header.dataTypeHint = CmdServicesMemTraceMemoryCompare::DataTypeHintValues::TraceNotype;
    header.addressSpace = addressSpace;

    auto headerSize = sizeof(CmdServicesMemTraceMemoryCompare) - sizeof(CmdServicesMemTraceMemoryCompare::data);
    auto blockSizeMax = dwordCountMax * sizeof(uint32_t) - headerSize;

    // We have to decompose memory into chunks that can be streamed per iteration
    while (sizeRemaining > 0) {
        AubMemDump::setAddress(header, physAddress);

        auto sizeThisIteration = std::min(sizeRemaining, blockSizeMax);

        // Round up to the number of dwords
        auto dwordCount = Math::divideAndRoundUp(headerSize + sizeThisIteration, sizeof(uint32_t));

        header.dwordCount = static_cast<uint32_t>(dwordCount - 1);
        header.dataSizeInBytes = static_cast<uint32_t>(sizeThisIteration);

        // Write the header
        write(reinterpret_cast<char *>(&header), headerSize);

        // Copy the contents from source to destination.
        write(reinterpret_cast<const char *>(memory), sizeThisIteration);

        sizeRemaining -= sizeThisIteration;
        memory = (uint8_t *)memory + sizeThisIteration;
        physAddress += sizeThisIteration;

        auto remainder = sizeThisIteration & (sizeof(uint32_t) - 1);
        if (remainder) {
            // if size is not 4 byte aligned, write extra zeros to AUB
            uint32_t zero = 0;
            write(reinterpret_cast<char *>(&zero), sizeof(uint32_t) - remainder);
        }
    }
}

void AubFileStream::createContext(const AubPpgttContextCreate &cmd) {
    write(reinterpret_cast<const char *>(&cmd), sizeof(cmd));
}

bool AubFileStream::addComment(const char *message) {
    using CmdServicesMemTraceComment = AubMemDump::CmdServicesMemTraceComment;
    CmdServicesMemTraceComment cmd = {};
    cmd.setHeader();
    cmd.syncOnComment = false;
    cmd.syncOnSimulatorDisplay = false;

    auto messageLen = strlen(message) + 1;
    auto dwordLen = ((messageLen + sizeof(uint32_t) - 1) & ~(sizeof(uint32_t) - 1)) / sizeof(uint32_t);
    cmd.dwordCount = static_cast<uint32_t>(dwordLen + 1);

    write(reinterpret_cast<char *>(&cmd), sizeof(cmd) - sizeof(cmd.comment));
    write(message, messageLen);
    auto remainder = messageLen & (sizeof(uint32_t) - 1);
    if (remainder) {
        // if size is not 4 byte aligned, write extra zeros to AUB
        uint32_t zero = 0;
        write(reinterpret_cast<char *>(&zero), sizeof(uint32_t) - remainder);
    }
    return true;
}

std::unique_lock<std::mutex> AubFileStream::lockStream() {
    return std::unique_lock<std::mutex>(mutex);
}

} // namespace AubMemDump
