Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tpu_raiden/api/torch/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tpu_raiden/frameworks/torch:_tpu_raiden_torch",
"@torch_tpu//shims/torch:pytorch",
"@torch_tpu//torch_tpu",
"@torch_tpu//torch_tpu:_loader",
],
Expand Down
47 changes: 44 additions & 3 deletions tpu_raiden/api/torch/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@

import os
import sys
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple

import torch
import torch_tpu # noqa: F401 # Load torch shared libraries before the extension.
from torch_tpu import _loader as _torch_tpu_loader


Expand Down Expand Up @@ -77,6 +75,7 @@ def __init__(
num_slots: int,
timeout_s: float = 120.0,
unsafe_skip_buffer_lock: bool = True,
listener_port: Optional[int] = None,
):
"""Instantiates the TransferEngine-based KVCacheManager.

Expand All @@ -89,6 +88,8 @@ def __init__(
num_slots: Number of transfer slots to allocate.
timeout_s: Timeout in seconds for transfer operations.
unsafe_skip_buffer_lock: Skip dynamic safety locking.
listener_port: Sockets server port for incoming C++ KVCacheListener
commands.
"""
self._impl = _impl.KVCacheManager(
kv_caches=kv_caches,
Expand All @@ -98,6 +99,7 @@ def __init__(
num_slots=num_slots,
timeout_s=timeout_s,
unsafe_skip_buffer_lock=unsafe_skip_buffer_lock,
listener_port=listener_port,
)

@property
Expand All @@ -110,6 +112,11 @@ def local_control_port(self) -> int:
"""Returns the active control plane listener port."""
return self._impl.local_control_port

@property
def local_port(self) -> int:
"""Returns the active data port."""
return self._impl.local_port

def register_read(
self, req_id: str, uuid: int, block_ids: List[int]
) -> bool:
Expand Down Expand Up @@ -195,3 +202,37 @@ def h2d(
if copy_sizes is None:
copy_sizes = [1] * len(src_offsets)
return self._impl.H2d(src_offsets, dst_offsets, copy_sizes)

@property
def listener_port(self) -> Optional[int]:
"""Returns the active local port assigned to the C++ KVCacheListener."""
return self._impl.listener_port

@property
def is_listener_active(self) -> bool:
"""Returns whether the native C++ KVCacheListener is actively running."""
return self._impl.is_listener_active

def h2d(
self,
src_offsets: List[int] = None,
dst_offsets: List[int] = None,
sizes: List[int] = None,
) -> Any:
"""Triggers Host-to-Device (H2D) copy of staged host buffer to Device memory."""
src_offsets = src_offsets or []
dst_offsets = dst_offsets or []
sizes = sizes or []
return self._impl.H2d(src_offsets, dst_offsets, sizes)

def d2h(
self,
src_offsets: List[int] = None,
dst_offsets: List[int] = None,
sizes: List[int] = None,
) -> Any:
"""Triggers Device-to-Host (D2H) copy of Device memory to Host buffer."""
src_offsets = src_offsets or []
dst_offsets = dst_offsets or []
sizes = sizes or []
return self._impl.D2h(src_offsets, dst_offsets, sizes)
5 changes: 2 additions & 3 deletions tpu_raiden/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,14 @@ cc_library(
":raw_transfer_core",
":tpu_utils",
"//tpu_raiden/kv_cache:kv_cache_manager_base",
"//tpu_raiden/transport:block_transport",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@xla//xla:future",
"@com_google_absl//absl/time",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/tsl/platform:errors",
],
)

Expand Down
65 changes: 65 additions & 0 deletions tpu_raiden/core/kv_cache_manager_with_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/tsl/platform/errors.h"
#include "tpu_raiden/core/host_memory_allocator.h"
#include "tpu_raiden/core/raw_transfer_core.h"
#include "tpu_raiden/core/tpu_utils.h"
Expand All @@ -77,6 +80,8 @@
namespace tpu_raiden {
namespace {

constexpr absl::Duration kPendingWorkTimeout = absl::Seconds(30);

[[noreturn]] void ThrowStatus(absl::string_view context,
const absl::Status& status) {
throw std::runtime_error(absl::StrCat(context, ": ", status.message()));
Expand Down Expand Up @@ -520,6 +525,45 @@ int64_t KVCacheManagerWithTransfer::NotifyForRead(
return static_cast<int64_t>(uuid);
}

absl::Status KVCacheManagerWithTransfer::RegisterActivePlan(
uint64_t uuid, const kv_cache::StartTransferRequest& request,
bool is_sender) {
// 1. Call base class implementation to register the plan in active_plans_
TF_RETURN_IF_ERROR(kv_cache::KVCacheManagerBase::RegisterActivePlan(
uuid, request, is_sender));

// 2. If we are the receiver and the destination memory type is HBM,
// populate active_recv_entries_ to enable automatic H2D copy!
if (!is_sender && request.dst_mem_type() == kv_cache::MEMORY_TYPE_HBM) {
std::lock_guard<std::mutex> lock(mu_);
RecvEntry recv_entry;
std::string req_id = absl::StrCat("resharded_transfer_", uuid);
recv_entry.req_id = req_id;

int64_t total_blocks = 0;
for (const auto& [src_replica_idx, schedule] :
request.shard_push_schedules()) {
std::set<int> unique_blocks_from_this_source;
for (const auto& push_entry : schedule.entries()) {
recv_entry.host_to_chip[push_entry.dst_block_id()] =
push_entry.dst_block_id();
unique_blocks_from_this_source.insert(push_entry.dst_block_id());
}
total_blocks += unique_blocks_from_this_source.size();
}
recv_entry.total_blocks = total_blocks;
if (total_blocks > 0) {
active_recv_entries_[uuid] = std::move(recv_entry);
LOG(INFO) << "RegisterActivePlan (Receiver): Populated "
"active_recv_entries_ for UUID "
<< uuid << " with " << total_blocks
<< " total physical block-pushes (including duplicates across "
"sources) for automatic H2D.";
}
}
return absl::OkStatus();
}

void KVCacheManagerWithTransfer::StartRead(
absl::string_view req_id, uint64_t uuid, absl::string_view remote_endpoint,
const std::vector<int64_t>& remote_block_ids,
Expand Down Expand Up @@ -1080,6 +1124,27 @@ absl::Status KVCacheManagerWithTransfer::OnBlocksReceived(
return absl::OkStatus();
}

absl::Status KVCacheManagerWithTransfer::WaitForPendingWork() {
LOG(INFO) << "Waiting for pending H2D transfers to complete...";
const absl::Time start = absl::Now();
while (true) {
{
std::lock_guard<std::mutex> lock(mu_);
if (active_recv_entries_.empty()) {
break;
}
const absl::Duration elapsed = absl::Now() - start;
if (elapsed > kPendingWorkTimeout) {
return absl::DeadlineExceededError(
"Timeout waiting for pending H2D transfers");
}
}
absl::SleepFor(absl::Milliseconds(100));
}
LOG(INFO) << "All pending H2D transfers completed.";
return absl::OkStatus();
}

std::string KVCacheManagerWithTransfer::EndpointWithPort(
absl::string_view endpoint, int port) const {
auto [host, ignored_port] = SplitEndpoint(endpoint);
Expand Down
6 changes: 6 additions & 0 deletions tpu_raiden/core/kv_cache_manager_with_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ class KVCacheManagerWithTransfer : public kv_cache::KVCacheManagerBase {
std::vector<std::string>>
CompleteReadRaw();

absl::Status RegisterActivePlan(uint64_t uuid,
const kv_cache::StartTransferRequest& request,
bool is_sender) override;

int local_control_port() const { return local_control_port_; }
int64_t node_id() const { return node_id_; }

Expand Down Expand Up @@ -266,6 +270,8 @@ class KVCacheManagerWithTransfer : public kv_cache::KVCacheManagerBase {
absl::Status OnBlocksReceived(const std::vector<int>& block_ids,
uint64_t uuid = 0) override;

absl::Status WaitForPendingWork() override;

struct RecvEntry {
std::string req_id;
int64_t total_blocks = 0;
Expand Down
6 changes: 6 additions & 0 deletions tpu_raiden/frameworks/torch/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ cc_library(
"//tpu_raiden/core:xla_raw_transfer_headers",
"@torch_tpu//shims/torch:aten_headers",
"@torch_tpu//shims/torch:torch_headers",
"@torch_tpu//shims/torch/torch/headeronly:torch_headeronly",
"@torch_tpu//torch_tpu/eager:structured_log_buffer",
],
)

Expand All @@ -89,6 +91,8 @@ cc_library(
deps = [
":torch_tpu_utils",
"//tpu_raiden/core:xla_raw_transfer_headers",
"@torch_tpu//shims/torch:aten_headers",
"@torch_tpu//shims/torch:torch_headers",
"@torch_tpu//torch_tpu/eager:tensor_to_buffer",
],
)
Expand Down Expand Up @@ -131,6 +135,7 @@ cc_library(
"//tpu_raiden/core:kv_cache_manager_with_transfer",
"//tpu_raiden/core:utils",
"//tpu_raiden/core:xla_raw_transfer_headers",
"//tpu_raiden/kv_cache:kv_cache_listener",
"@torch_tpu//shims/torch:aten_headers",
"@torch_tpu//torch_tpu/eager:tensor_to_buffer",
],
Expand Down Expand Up @@ -289,6 +294,7 @@ cc_library(
"//tpu_raiden/core:kv_cache_manager_with_transfer",
"//tpu_raiden/core:utils",
"//tpu_raiden/core:xla_raw_transfer_headers",
"//tpu_raiden/kv_cache:kv_cache_listener",
"@torch_tpu//shims/torch:aten_headers",
"@torch_tpu//shims/torch:torch_headers",
"@torch_tpu//torch_tpu/eager:tensor_to_buffer",
Expand Down
27 changes: 25 additions & 2 deletions tpu_raiden/frameworks/torch/kv_cache_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "tpu_raiden/frameworks/torch/kv_cache_manager.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
Expand All @@ -24,6 +25,7 @@
#include "tpu_raiden/core/kv_cache_manager_with_transfer.h"
#include "tpu_raiden/core/utils.h"
#include "tpu_raiden/frameworks/torch/torch_utils.h"
#include "tpu_raiden/kv_cache/kv_cache_listener.h"

namespace tpu_raiden {
namespace torch {
Expand Down Expand Up @@ -89,15 +91,36 @@ KVCacheManager::KVCacheManager(UnpackedLayers unpacked,
KVCacheManager::KVCacheManager(const std::vector<at::Tensor>& kv_caches,
int64_t node_id, int64_t local_control_port,
int64_t max_blocks, int64_t num_slots,
double timeout_s, bool unsafe_skip_buffer_lock)
double timeout_s, bool unsafe_skip_buffer_lock,
std::optional<int> listener_port)
: KVCacheManager(UnpackLayers(SingleShardLayers(kv_caches)),
/*local_port=*/std::nullopt,
/*host_blocks_to_allocate=*/std::nullopt,
unsafe_skip_buffer_lock, /*parallelism=*/1, node_id,
local_control_port, max_blocks, num_slots, timeout_s,
kv_caches) {}
kv_caches) {
if (listener_port) {
listener_ =
std::make_unique<tpu_raiden::kv_cache::KVCacheListener>(
this, *listener_port);
}
}

KVCacheManager::~KVCacheManager() = default;

std::optional<int> KVCacheManager::listener_port() const {
if (listener_) {
return listener_->listener_port();
}
return std::nullopt;
}

bool KVCacheManager::is_listener_active() const {
if (listener_) {
return listener_->is_active();
}
return false;
}

} // namespace torch
} // namespace tpu_raiden
23 changes: 17 additions & 6 deletions tpu_raiden/frameworks/torch/kv_cache_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define THIRD_PARTY_TPU_RAIDEN_TPU_RAIDEN_API_TORCH_KV_CACHE_MANAGER_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

Expand All @@ -24,24 +25,30 @@
#include "xla/pjrt/pjrt_client.h"
#include "tpu_raiden/core/kv_cache_manager_with_transfer.h"

namespace tpu_raiden {
namespace kv_cache {
class KVCacheListener;
} // namespace kv_cache
} // namespace tpu_raiden

namespace tpu_raiden {
namespace torch {

class KVCacheManager : public KVCacheManagerWithTransfer {
public:
// PyTorch sharded constructor E2E (cache-only by default)
KVCacheManager(
const std::vector<std::vector<at::Tensor>>& device_tensors,
std::optional<int> local_port = std::nullopt,
std::optional<int> host_blocks_to_allocate = std::nullopt,
bool unsafe_skip_buffer_lock = false, int parallelism = 1);
KVCacheManager(const std::vector<std::vector<at::Tensor>>& device_tensors,
std::optional<int> local_port = std::nullopt,
std::optional<int> host_blocks_to_allocate = std::nullopt,
bool unsafe_skip_buffer_lock = false, int parallelism = 1);

// New transfer-enabled constructor (flat list of tensors, single shard per
// layer)
KVCacheManager(const std::vector<at::Tensor>& kv_caches, int64_t node_id,
int64_t local_control_port, int64_t max_blocks,
int64_t num_slots, double timeout_s,
bool unsafe_skip_buffer_lock);
bool unsafe_skip_buffer_lock,
std::optional<int> listener_port = std::nullopt);

~KVCacheManager() override;

Expand All @@ -53,6 +60,9 @@ class KVCacheManager : public KVCacheManagerWithTransfer {

const std::vector<at::Tensor>& kv_caches() const { return kv_caches_; }

std::optional<int> listener_port() const;
bool is_listener_active() const;

private:
// Buffers unpacked from a 2D tensor list, together with the owning
// DeviceBufferRefs that must outlive their use (see UnpackTorchTensor).
Expand All @@ -78,6 +88,7 @@ class KVCacheManager : public KVCacheManagerWithTransfer {
std::vector<at::Tensor> kv_caches_;
// Keep-alives for the materialized device buffers backing the manager.
std::vector<torch_tpu::DeviceBufferRef> buffer_refs_;
std::unique_ptr<tpu_raiden::kv_cache::KVCacheListener> listener_;
};

} // namespace torch
Expand Down
Loading