// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Licensed under the MIT License.

#include <iostream>
#include <fstream>
#include <filesystem>

#include "onnx_ctx_model_helper.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/execution_provider.h"
#include "nv_execution_provider.h"

namespace onnxruntime {
// Declaration only: the definition is not in this file (provided by another translation unit).
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);

/*
 * Render a raw byte buffer as a lowercase hexadecimal string.
 * Every input byte yields two characters (high nibble first), so the result is 2 * size long.
 */
std::string BinaryToHexString(const void* data, size_t size) {
  static const char kHexDigits[] = "0123456789abcdef";
  const auto* input = static_cast<const uint8_t*>(data);

  // Pre-size the output and fill it in place, two characters per byte.
  std::string hex(size * 2, '0');
  for (size_t pos = 0; pos < size; ++pos) {
    const unsigned value = input[pos];
    hex[2 * pos] = kHexDigits[(value >> 4) & 0xF];
    hex[2 * pos + 1] = kHexDigits[value & 0xF];
  }
  return hex;
}

/*
 * Convert hex string back to binary
 *
 * Both lower-case and upper-case digits are accepted. Two characters produce one byte,
 * high nibble first.
 *
 * Throws (via ORT_THROW) if the string has odd length or contains a character that is
 * not a hexadecimal digit.
 */
std::vector<uint8_t> HexStringToBinary(const std::string& hex) {
  if (hex.size() % 2 != 0) {
    ORT_THROW("Hex string must have even length");
  }

  // Maps one hex digit to its 4-bit value, or -1 when the character is not a hex digit.
  const auto nibble_value = [](char c) -> int {
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'a' && c <= 'f') return c - 'a' + 10;
    if (c >= 'A' && c <= 'F') return c - 'A' + 10;
    return -1;
  };

  std::vector<uint8_t> result;
  result.reserve(hex.size() / 2);

  for (size_t i = 0; i < hex.size(); i += 2) {
    const int high = nibble_value(hex[i]);
    const int low = nibble_value(hex[i + 1]);
    if (high < 0 || low < 0) {
      // Reject bad input instead of decoding it as a 0 nibble, which would silently corrupt the data.
      ORT_THROW("Hex string contains a non-hexadecimal character at offset ", (high < 0 ? i : i + 1));
    }
    result.push_back(static_cast<uint8_t>((high << 4) | low));
  }
  return result;
}

/*
 *  Look for an EP context contrib op in the graph.
 *  The op can contain the precompiled engine info for TRT EP to directly load the engine.
 *
 *  On success node_idx receives the index of the first matching node and true is returned.
 *  If no node matches, node_idx is left untouched and false is returned.
 *
 *  Note: Please see more details about "EPContext" contrib op in contrib_defs.cc
 */
bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx) {
  for (int idx = 0; idx < graph_viewer.MaxNodeIndex(); ++idx) {
    const auto* candidate = graph_viewer.GetNode(idx);
    // Skip empty slots and every node that is not an EP context op.
    if (candidate == nullptr || candidate->OpType() != EPCONTEXT_OP) {
      continue;
    }
    node_idx = idx;
    return true;
  }
  return false;
}

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
  // find the top level graph
  const Graph* cur_graph = &graph_viewer.GetGraph();
  while (cur_graph->IsSubgraph()) {
    cur_graph = cur_graph->ParentGraph();
  }

  const Graph& main_graph = *cur_graph;
  return main_graph.ModelPath();
}

/*
 * Replace the ep_cache_context attribute of the EP context node with the given engine binary data.
 *
 * The EP context node is taken to be the first node (index 0) of the model's graph.
 * A size of 0 stores an empty string.
 */
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
                                     char* engine_data,
                                     size_t size) {
  ONNX_NAMESPACE::NodeProto* ctx_node = model_proto->mutable_graph()->mutable_node(0);

  const int attribute_count = ctx_node->attribute_size();
  for (int attr_idx = 0; attr_idx < attribute_count; ++attr_idx) {
    ONNX_NAMESPACE::AttributeProto* attr = ctx_node->mutable_attribute(attr_idx);
    if (attr->name() != EP_CACHE_CONTEXT) {
      continue;
    }

    std::string payload;
    if (size > 0) {
      payload.assign(engine_data, size);
    }
    attr->set_s(payload);
  }
}

/*
 * Create EP context node where engine information is embedded
 *
 * A single EPContext node is added to graph_build; its inputs and outputs mirror the
 * inputs and outputs of graph_viewer.
 *  - embed_mode != 0: the engine bytes are stored in the ep_cache_context attribute.
 *  - embed_mode == 0: the engine bytes are written to engine_cache_path and only the file name
 *                     (without directory) is stored in the ep_cache_context attribute.
 *
 * Returns an EP_FAIL status if the engine cache file cannot be opened or completely written.
 * A failing graph_build.Resolve() is reported through ORT_ENFORCE.
 */
Status CreateCtxNode(const GraphViewer& graph_viewer,
                     Graph& graph_build,
                     const std::string engine_cache_path,
                     char* engine_data,
                     size_t size,
                     const int64_t embed_mode,
                     const std::string compute_capability,
                     const std::string onnx_model_path,
                     const std::string& ep_context_node_name,
                     int32_t trt_version) {
  // Get graph inputs and outputs
  std::vector<onnxruntime::NodeArg*> inputs, outputs;
  for (auto input : graph_viewer.GetInputs()) {
    auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
    inputs.push_back(&n_input);
  }

  for (auto output : graph_viewer.GetOutputs()) {
    auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
    outputs.push_back(&n_output);
  }

  // Create EP context node attributes
  auto attr_embed_mode = ONNX_NAMESPACE::AttributeProto::Create();
  auto attr_main_context = ONNX_NAMESPACE::AttributeProto::Create();
  auto attr_ep_cache_context = ONNX_NAMESPACE::AttributeProto::Create();
  auto attr_sdk_version = ONNX_NAMESPACE::AttributeProto::Create();
  auto attr_hw_architecture = ONNX_NAMESPACE::AttributeProto::Create();
  auto attr_onnx_filename = ONNX_NAMESPACE::AttributeProto::Create();
  auto attr_partition_name = ONNX_NAMESPACE::AttributeProto::Create();

  attr_main_context->set_name(MAIN_CONTEXT);
  attr_main_context->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_main_context->set_i(0);  // we do not support a main context node but each has its own engine payload

  attr_embed_mode->set_name(EMBED_MODE);
  attr_embed_mode->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_embed_mode->set_i(embed_mode);

  attr_ep_cache_context->set_name(EP_CACHE_CONTEXT);
  attr_ep_cache_context->set_type(onnx::AttributeProto_AttributeType_STRING);
  if (embed_mode) {
    // Embed the serialized engine directly in the attribute.
    std::string engine_data_str;
    if (size > 0) {
      engine_data_str.assign(engine_data, size);
    }
    attr_ep_cache_context->set_s(engine_data_str);
  } else {
    // Store only the file name in the attribute and write the engine next to it on disk.
    std::string engine_cache_filename = std::filesystem::path(engine_cache_path).filename().string();
    attr_ep_cache_context->set_s(engine_cache_filename);

    std::fstream engine_cache_file(engine_cache_path, std::ios::binary | std::ios::out);
    if (!engine_cache_file.is_open()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                             "NvTensorRTRTX EP could not write cache to ", engine_cache_path);
    }
    engine_cache_file.write(engine_data, static_cast<std::streamsize>(size));
    engine_cache_file.close();
    // A failed write or flush (e.g. disk full) would otherwise leave a truncated engine behind
    // while reporting success.
    if (engine_cache_file.fail()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                             "NvTensorRTRTX EP failed while writing cache to ", engine_cache_path);
    }
  }

  attr_hw_architecture->set_name(COMPUTE_CAPABILITY);
  attr_hw_architecture->set_type(onnx::AttributeProto_AttributeType_STRING);
  attr_hw_architecture->set_s(compute_capability);

  attr_partition_name->set_name(PARTITION_NAME);
  attr_partition_name->set_type(onnx::AttributeProto_AttributeType_STRING);
  attr_partition_name->set_s(ep_context_node_name);  // includes hash of the subgraph that was built

  attr_onnx_filename->set_name(ONNX_MODEL_FILENAME);
  attr_onnx_filename->set_type(onnx::AttributeProto_AttributeType_STRING);
  attr_onnx_filename->set_s(std::filesystem::path(onnx_model_path).filename().string());

  attr_sdk_version->set_name(SDK_VERSION);
  attr_sdk_version->set_type(onnx::AttributeProto_AttributeType_STRING);
  attr_sdk_version->set_s(std::to_string(trt_version));

  auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
  // Must match the number of emplace() calls below.
  constexpr int num_attributes = 7;
  node_attributes->reserve(num_attributes);
  node_attributes->emplace(MAIN_CONTEXT, *attr_main_context);
  node_attributes->emplace(EMBED_MODE, *attr_embed_mode);
  node_attributes->emplace(EP_CACHE_CONTEXT, *attr_ep_cache_context);
  node_attributes->emplace(COMPUTE_CAPABILITY, *attr_hw_architecture);
  node_attributes->emplace(PARTITION_NAME, *attr_partition_name);
  node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_onnx_filename);
  node_attributes->emplace(SDK_VERSION, *attr_sdk_version);

  // Create EP context node
  graph_build.AddNode(ep_context_node_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
  ORT_ENFORCE(graph_build.Resolve().IsOK());
  return Status::OK();
}

/*
 * Resolve the directory in which the ep context model is located.
 *
 * - Empty input            -> empty path.
 * - Existing directory     -> that directory.
 * - Anything else          -> the parent path of the input.
 */
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) {
  namespace fs = std::filesystem;

  if (ep_context_file_path.empty()) {
    return fs::path{};
  }

  const fs::path ctx_location(ep_context_file_path);
  return fs::is_directory(ctx_location) ? ctx_location : ctx_location.parent_path();
}

/*
 * Get "EP context" model path.
 *
 * Function logic:
 * If ep_context_file_path is provided,
 *     - If ep_context_file_path is not an existing directory, it is treated as a file path and
 *       returned unchanged.
 *     - If ep_context_file_path is an existing directory, return
 *       "ep_context_file_path/original_model_name_ctx.onnx".
 * If ep_context_file_path is not provided,
 *     - Return "original_model_name_ctx.onnx".
 *
 * TRT EP has rules about context model path and engine cache path (see tensorrt_execution_provider.cc):
 * - If dump_ep_context_model_ and engine_cache_enabled_ is enabled, TRT EP will dump context model and save engine cache
 *   to the same directory provided by ep_context_file_path_. (i.e. engine_cache_path_ = ep_context_file_path_)
 *
 * Example 1 (ep_context_file_path is an existing directory):
 * ep_context_file_path = "/home/user/ep_context_model_directory"
 * original_model_path = "model.onnx"
 * => return "/home/user/ep_context_model_directory/model_ctx.onnx"
 *
 * Example 2:
 * ep_context_file_path = "my_ctx_model.onnx"
 * original_model_path = "model.onnx"
 * => return "my_ctx_model.onnx"
 *
 * Example 3:
 * ep_context_file_path = "/home/user2/ep_context_model_directory/my_ctx_model.onnx"
 * original_model_path = "model.onnx"
 * => return "/home/user2/ep_context_model_directory/my_ctx_model.onnx"
 *
 */
std::string GetCtxModelPath(const std::string& ep_context_file_path,
                            const std::string& original_model_path) {
  namespace fs = std::filesystem;

  // An explicit path that is not a directory is used verbatim.
  if (!ep_context_file_path.empty() && !fs::is_directory(ep_context_file_path)) {
    return ep_context_file_path;
  }

  // Otherwise derive the file name from the original model: model_name.onnx -> model_name_ctx.onnx
  const std::string derived_name = fs::path(original_model_path).stem().string() + "_ctx.onnx";

  if (!fs::is_directory(ep_context_file_path)) {
    return derived_name;
  }
  return (fs::path(ep_context_file_path) / derived_name).string();
}

/*
 * Tell whether path_string denotes an absolute path.
 *
 * Windows: the string is converted with ToPathString and std::filesystem::path::is_absolute() decides.
 * Other platforms: a path is absolute exactly when it starts with '/'.
 */
bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
  const onnxruntime::PathString native_path = onnxruntime::ToPathString(path_string);
  return std::filesystem::path(native_path.c_str()).is_absolute();
#else
  return !path_string.empty() && path_string.front() == '/';
#endif
}

// Like "../file_path"
// Reports whether the path text contains "..".
// Windows: the path is lexically normalized and converted to preferred separators before the search.
// Other platforms: the raw string is searched as-is.
// Note that this is a substring match, so any name containing two consecutive dots is reported as well.
bool IsRelativePathToParentPath(const std::string& path_string) {
#ifdef _WIN32
  const onnxruntime::PathString native_path = onnxruntime::ToPathString(path_string);
  std::filesystem::path normalized = std::filesystem::path(native_path.c_str()).lexically_normal();
  const std::wstring normalized_text = normalized.make_preferred().wstring();
  return normalized_text.find(L"..") != std::wstring::npos;
#else
  return path_string.find("..") != std::string::npos;
#endif
}

/*
 * Load the engine described by an EP context node into *trt_engine_.
 *
 *  - embed_mode != 0: the engine is deserialized from the bytes held in the ep_cache_context attribute.
 *  - embed_mode == 0: ep_cache_context names an engine cache file relative to the directory of the
 *                     context model. Absolute paths and paths containing ".." are rejected.
 *
 * If weight_stripped_engine_refit_ is set, the freshly deserialized engine is refitted through
 * NvExecutionProvider::RefitEngine.
 *
 * Returns an EP_FAIL status for a rejected path, a missing or empty cache file, a failed
 * deserialization or a failed refit; errors from reading the cache file are propagated as-is.
 */
Status TensorRTCacheModelHandler::GetEpContextFromGraph(const Node& node) {
  auto& attrs = node.GetAttributes();

  const int64_t embed_mode = attrs.at(EMBED_MODE).i();
  // Only make path checks if model not provided as byte buffer
  bool make_secure_path_checks = ep_context_model_path_.empty();

  // Refit step shared by both branches. It is invoked from inside each branch so that the
  // serialized engine source (attribute string / mapped file) is still alive while refitting.
  const auto refit_engine_if_requested = [&]() -> Status {
    if (!weight_stripped_engine_refit_) {
      return Status::OK();
    }
    const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
    auto status = NvExecutionProvider::RefitEngine(onnx_model_filename,
                                                   onnx_model_folder_path_,
                                                   make_secure_path_checks,
                                                   onnx_model_bytestream_,
                                                   onnx_model_bytestream_size_,
                                                   onnx_external_data_bytestream_,
                                                   onnx_external_data_bytestream_size_,
                                                   (*trt_engine_).get(),
                                                   detailed_build_log_);
    if (status != Status::OK()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
    }
    return Status::OK();
  };

  if (embed_mode) {
    // Get engine from byte stream.
    const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
    *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
                                                                                                static_cast<size_t>(context_binary.length())));
    if (!(*trt_engine_)) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                             "Nv EP could not deserialize engine from binary data");
    }
    // Log only after the engine is known to be valid.
    LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";

    return refit_engine_if_requested();
  }

  // Get engine from cache file.
  std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s();

  // For security purpose, in the case of running context model, TRT EP won't allow
  // engine cache path to be the relative path like "../file_path" or the absolute path.
  // It only allows the engine cache to be in the same directory or sub directory of the context model.
  if (IsAbsolutePath(cache_path)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path:  " + cache_path);
  }
  if (IsRelativePathToParentPath(cache_path)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory.");
  }

  // The engine cache and context model (current model) should be in the same directory
  std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
  auto engine_cache_path = ctx_model_dir.append(cache_path);
  LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string();

  if (!std::filesystem::exists(engine_cache_path)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                           "Nv EP can't find engine cache: " + engine_cache_path.string() +
                               ". Please make sure engine cache is in the same directory or sub-directory of context model.");
  }

  size_t file_length = 0;
  auto path_str = ToPathString(engine_cache_path.string());

  // Map the cache file into memory; an empty file cannot hold an engine.
  Env::MappedMemoryPtr engine_buf;
  const auto& env = GetDefaultEnv();
  ORT_RETURN_IF_ERROR(env.GetFileLength(path_str.c_str(), file_length));
  if (!file_length) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                           "Nv EP could not read engine from cache: " + engine_cache_path.string());
  }
  ORT_RETURN_IF_ERROR(env.MapFileIntoMemory(path_str.c_str(), 0, file_length, engine_buf));

  *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), file_length));
  if (!(*trt_engine_)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                           "Nv EP could not deserialize engine from cache: " + engine_cache_path.string());
  }
  LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] DeSerialized " + engine_cache_path.string();

  return refit_engine_if_requested();
}

/*
 * The sanity check for EP context contrib op.
 */
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const Node& node) {
  auto& attrs = node.GetAttributes();

  // Show the warning if compute capability is not matched
  if (attrs.count(COMPUTE_CAPABILITY) > 0) {
    std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
    // Verify if engine was compiled with ampere+ hardware compatibility enabled
    if (model_compute_capability == "80+") {
      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible to all Ampere+ GPU (except Jetson)";
      if (std::stoi(compute_capability_) < 80) {
        LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] However, this GPU doesn't match. The compute capability of the GPU: " << compute_capability_;
      }
    } else if (model_compute_capability != compute_capability_) {
      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal";
      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The compute capability of the engine: " << model_compute_capability;
      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The compute capability of the GPU: " << compute_capability_;
    }
  }

  // "embed_mode" attr and "ep_cache_context" attr should be present
  assert(attrs.count(EMBED_MODE) > 0);
  assert(attrs.count(EP_CACHE_CONTEXT) > 0);

  const int64_t embed_mode = attrs.at(EMBED_MODE).i();
  if (embed_mode == 1) {
    // engine binary data
    // LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
  }

  return true;
}
}  // namespace onnxruntime
