Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions csrc/cache/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ StaticKVCache::StaticKVCache(
k_dim_},
dtype_,
rank_info.device);
set_zeros(k_caches_);

// Allocate V cache
v_caches_ = infinicore::Tensor::empty(
Expand All @@ -77,6 +78,9 @@ StaticKVCache::StaticKVCache(
v_dim_},
dtype_,
rank_info.device);
set_zeros(v_caches_);

infinicore::context::syncStream();
}

infinicore::Tensor StaticKVCache::create_layer_kv_cache(
Expand Down Expand Up @@ -110,6 +114,9 @@ infinicore::Tensor StaticKVCache::create_layer_kv_cache(
kv_dim},
dtype,
rank_info.device);
set_zeros(kv_cache);

infinicore::context::syncStream();

return kv_cache;
}
Expand Down Expand Up @@ -211,6 +218,7 @@ PagedKVCache::PagedKVCache(
k_dim_},
dtype_,
rank_info.device);
set_zeros(k_caches_);
Comment thread
qinyiqun marked this conversation as resolved.

// [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
v_caches_ = infinicore::Tensor::empty(
Expand All @@ -221,6 +229,9 @@ PagedKVCache::PagedKVCache(
v_dim_},
dtype_,
rank_info.device);
set_zeros(v_caches_);

infinicore::context::syncStream();
}

infinicore::Tensor PagedKVCache::create_layer_kv_cache(
Expand Down Expand Up @@ -256,6 +267,9 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
kv_shape,
dtype,
rank_info.device);
set_zeros(kv_cache);

infinicore::context::syncStream();

return kv_cache;
}
Expand Down
6 changes: 3 additions & 3 deletions csrc/config/model_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ ModelConfig::ModelConfig(const std::string &path) {
this->quant_config = QuantConfig(config_json["quantization_config"]);
}

infinicore::quantization::QuantScheme
infinilm::quantization::QuantScheme
ModelConfig::get_quant_scheme() const {
if (quant_config.get_quant_scheme() != infinicore::quantization::QuantScheme::NONE) {
if (quant_config.get_quant_scheme() != infinilm::quantization::QuantScheme::NONE) {
return quant_config.get_quant_scheme();
} else {
return infinicore::quantization::QuantScheme::NONE;
return infinilm::quantization::QuantScheme::NONE;
}
}

Expand Down
6 changes: 3 additions & 3 deletions csrc/config/model_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ class ModelConfig {
return quant_config;
}

std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization_method() const {
std::shared_ptr<infinilm::quantization::BaseQuantization> get_quantization_method() const {
return quant_config.get_quantization_method();
}

infinicore::DataType get_dtype() const;
infinicore::quantization::QuantScheme get_quant_scheme() const;
infinilm::quantization::QuantScheme get_quant_scheme() const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;
void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) {
this->quant_config.set_kv_quant_scheme(kv_cache_dtype);
}
infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
infinilm::quantization::KVQuantAlgo get_kv_quant_scheme() const {
return quant_config.get_kv_quant_scheme();
}
infinicore::DataType get_kv_cache_dtype() const {
Expand Down
15 changes: 7 additions & 8 deletions csrc/config/quant_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,24 @@ QuantConfig::QuantConfig(const nlohmann::json &json) : quantization_config(json)
this->quantization_method = get_quantization_method();
}

std::shared_ptr<infinicore::quantization::BaseQuantization>
std::shared_ptr<infinilm::quantization::BaseQuantization>
QuantConfig::get_quantization_method() const {
if (quantization_config.is_null()) {
return std::make_shared<infinicore::quantization::NoneQuantization>(quantization_config); // Default case if no matching scheme
return std::make_shared<infinilm::quantization::NoneQuantization>(quantization_config); // Default case if no matching scheme
}

// Determine the quantization scheme from the JSON config
if (quantization_config["quant_method"] == "compressed-tensors") {
return std::make_shared<infinicore::quantization::CompressedTensors>(quantization_config);
return std::make_shared<infinilm::quantization::CompressedTensors>(quantization_config);
} else if (quantization_config["quant_method"] == "awq") {
return std::make_shared<infinicore::quantization::AWQ>(quantization_config);
return std::make_shared<infinilm::quantization::AWQ>(quantization_config);
} else if (quantization_config["quant_method"] == "gptq") {
// return std::make_shared<infinicore::quantization::GPTQ_QY>(quantization_config);
return std::make_shared<infinicore::quantization::GPTQ>(quantization_config);
return std::make_shared<infinilm::quantization::GPTQ>(quantization_config);
} else {
return std::make_shared<infinicore::quantization::NoneQuantization>(quantization_config);
return std::make_shared<infinilm::quantization::NoneQuantization>(quantization_config);
}
// Add other schemes as needed

return std::make_shared<infinicore::quantization::NoneQuantization>(quantization_config); // Default case if no matching scheme
return std::make_shared<infinilm::quantization::NoneQuantization>(quantization_config); // Default case if no matching scheme
}
} // namespace infinilm::config
20 changes: 10 additions & 10 deletions csrc/config/quant_config.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once
#include "../utils.hpp"
#include "infinicore/quantization.hpp"
#include "../layers/quantization/quantization.hpp"
#include "nlohmann/json.hpp"
#include <optional>
#include <spdlog/spdlog.h>
Expand All @@ -14,13 +14,13 @@ class QuantConfig {
QuantConfig() = default;
QuantConfig(const nlohmann::json &json);

std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization_method() const;
std::shared_ptr<infinilm::quantization::BaseQuantization> get_quantization_method() const;

infinicore::quantization::QuantScheme get_quant_scheme() const {
infinilm::quantization::QuantScheme get_quant_scheme() const {
if (quantization_method != nullptr) {
return quantization_method->get_quant_scheme();
} else {
return infinicore::quantization::QuantScheme::NONE;
return infinilm::quantization::QuantScheme::NONE;
}
}

Expand All @@ -29,22 +29,22 @@ class QuantConfig {
this->kv_cache_dtype_ = std::make_optional(kv_cache_dtype);
switch (kv_cache_dtype) {
case infinicore::DataType::I8: {
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8;
this->kv_quant_scheme = infinilm::quantization::KVQuantAlgo::INT8;
break;
}
default: {
spdlog::warn("Unsupported kv_cache_dtype: '{}', fallback to NONE", infinicore::toString(kv_cache_dtype));
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
this->kv_quant_scheme = infinilm::quantization::KVQuantAlgo::NONE;
break;
}
}
} catch (const std::exception &e) {
spdlog::error("Failed to parse kv_cache_dtype '{}': {}", infinicore::toString(kv_cache_dtype), e.what());
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
this->kv_quant_scheme = infinilm::quantization::KVQuantAlgo::NONE;
}
}

infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
infinilm::quantization::KVQuantAlgo get_kv_quant_scheme() const {
return kv_quant_scheme;
}

Expand All @@ -57,9 +57,9 @@ class QuantConfig {

private:
nlohmann::json quantization_config;
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_method;
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization_method;

infinicore::quantization::KVQuantAlgo kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
infinilm::quantization::KVQuantAlgo kv_quant_scheme = infinilm::quantization::KVQuantAlgo::NONE;
std::optional<infinicore::DataType> kv_cache_dtype_ = std::nullopt;
};

Expand Down
15 changes: 1 addition & 14 deletions csrc/engine/compiler/paged_compiler.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
#include "paged_compiler.hpp"
#include "../../global_state/global_state.hpp"
#include "../../utils.hpp"

namespace {
// Todo: replace with Tensor::zeros when it is available
inline void set_zeros(infinicore::Tensor &tensor) {
std::vector<uint8_t> zeros(tensor->nbytes(), 0);
infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false);
}

inline void set_minus_one(infinicore::Tensor &tensor) {
// For int32 tensors, 0xFF bytes correspond to -1 in two's complement.
std::vector<uint8_t> minus_one(tensor->nbytes(), 0xFF);
infinicore::context::memcpyH2D(tensor->data(), minus_one.data(), tensor->nbytes(), false);
}

} // namespace
namespace infinilm::engine {
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
Expand Down
43 changes: 0 additions & 43 deletions csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,6 @@ namespace infinilm::engine {
//------------------------------------------------------
// Constructor
//------------------------------------------------------
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
InferEngine::InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling,
backends::AttentionBackend attention_backend) // Changed parameter
: communication_group_(distributed_config, device_type),
legacy_model_config_(config),
attention_backend_(attention_backend) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(
legacy_model_config_,
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling,
attention_backend_));
}

// Compile the model on all workers
this->compile();
}

InferEngine::InferEngine(
const std::string &config_str,
const distributed::DistConfig &distributed_config,
Expand Down
22 changes: 0 additions & 22 deletions csrc/engine/infer_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "../config/model_config.hpp"
#include "../global_state/global_state.hpp"
#include "../models/infinilm_model.hpp"
#include "../models/llama_legacy/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_barrier.hpp"
Expand All @@ -21,26 +20,6 @@ class InferEngine {
using Output = RankWorker::Output;

// Updated constructor: accept CacheConfig instead of CacheType
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

InferEngine(
const std::string &config_str,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
Expand Down Expand Up @@ -78,7 +57,6 @@ class InferEngine {
std::unique_ptr<RankBarrier> barrier_;
distributed::CommunicationGroup communication_group_;
std::unique_ptr<cache::CacheConfig> cache_config_;
const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
};
Expand Down
49 changes: 0 additions & 49 deletions csrc/engine/rank_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,46 +10,6 @@

namespace infinilm::engine {

/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling,
backends::AttentionBackend attention_backend)
: legacy_model_config_(model_config),
rank_info_(rank_info),
attention_backend_(attention_backend),
enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false),
rng_(std::random_device{}()),
barrier_(barrier) {
if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
// start the thread
thread_ = std::thread(&RankWorker::thread_loop, this);

// Wait until the worker thread finishes initialization (model created)
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_; });
}

RankWorker::RankWorker(
std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config,
const distributed::RankInfo &rank_info,
Expand Down Expand Up @@ -269,15 +229,6 @@ void RankWorker::thread_loop() {
infinilm::global_state::initialize_infinilm_config(infinilm_config_);

// Create model using factory (may be expensive)
if (model_config_ == nullptr) {
// model_ = InfinilmModelFactory::createModel(
// legacy_model_config_,
// rank_info_,
// pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
// attention_backend_);
throw std::runtime_error("RankWorker::thread_loop(): the way of creating models using LlamaConfig is no longer supported !!!");
}

const std::string &model_type = model_config_->get<std::string>("model_type");
const auto &model_map = models::get_causal_lm_model_map();
auto it = model_map.find(model_type);
Expand Down
8 changes: 0 additions & 8 deletions csrc/engine/rank_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,6 @@ class RankWorker {
infinicore::Tensor output_ids;
};

RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling,
backends::AttentionBackend attention_backend);

RankWorker(std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
Expand Down Expand Up @@ -118,7 +111,6 @@ class RankWorker {

private:
// Worker properties
const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
engine::distributed::RankInfo rank_info_;
Expand Down
Loading