3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -48,6 +48,9 @@ endif()
 # Framework core sources (*.cc), excluding cpu kernels (they are built separately)
 file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
 list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
+if(NOT USE_NCCL)
+    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
+endif()
 
 # CPU kernels (*.cc)
 file(GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc)
15 changes: 8 additions & 7 deletions example/gpt2/main.cc
@@ -10,7 +10,7 @@
#include "glog/logging.h"

#include "infini_train/include/autocast.h"
#include "infini_train/include/core/device_guard.h"
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/modules/loss.h"
@@ -140,24 +140,25 @@ void Train(const nn::parallel::Rank &rank) {

     if (rank.IsParallel()) {
         device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
+        auto *pg_factory = ProcessGroupFactory::Instance(device.type());
 
         if (ddp_world_size > 1) {
-            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
-                                                                  GetDataParallelGroupRanks(rank.GlobalRank()));
+            ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                             GetDataParallelGroupRanks(rank.GlobalRank()));
             ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
         }
 
         if (tp_world_size > 1) {
-            tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetTensorParallelGroupRanks(rank.GlobalRank()));
+            tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                            GetTensorParallelGroupRanks(rank.GlobalRank()));
             tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }
 
         if (pp_world_size > 1) {
-            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetPipelineParallelGroupRanks(rank.GlobalRank()));
+            pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                            GetPipelineParallelGroupRanks(rank.GlobalRank()));
             pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
 
             nn::parallel::pp_rank = pp_rank;
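The Train() drivers now resolve the process-group factory once, as pg_factory, keyed by the device type, instead of calling ProcessGroupFactory::Instance() with no argument in every branch. A minimal sketch of what a per-device-type accessor could look like; this shape is an assumption, since the real ProcessGroupFactory declaration is not part of this diff:

// Hedged sketch only: one lazily created factory per device type, so CPU and
// CUDA process groups can coexist in a single process.
#include <memory>
#include <mutex>
#include <unordered_map>

#include "infini_train/include/device.h"

class ProcessGroupFactory {
public:
    static ProcessGroupFactory *Instance(infini_train::Device::DeviceType type) {
        static std::mutex mu;
        static std::unordered_map<infini_train::Device::DeviceType,
                                  std::unique_ptr<ProcessGroupFactory>> factories;
        std::lock_guard<std::mutex> lock(mu);
        auto &slot = factories[type];
        if (!slot) {
            slot = std::unique_ptr<ProcessGroupFactory>(new ProcessGroupFactory(type));
        }
        return slot.get();
    }

private:
    explicit ProcessGroupFactory(infini_train::Device::DeviceType type) : type_(type) {}
    infini_train::Device::DeviceType type_;
};

Hoisting the handle also keeps the ddp/tp/pp branches free of repeated lookups.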
4 changes: 2 additions & 2 deletions example/gpt2/net.cc
@@ -198,8 +198,8 @@ GPT2FirstStage::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>>
     auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled();
     int tp_rank = 0;
     if (tp_world_size > 1) {
-        auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
-            nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
+        auto tp_group = nn::parallel::ProcessGroupFactory::Instance(device.type())
+                            ->Get(nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
         tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank());
     }
     int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1];
15 changes: 8 additions & 7 deletions example/llama3/main.cc
@@ -8,7 +8,7 @@
#include "glog/logging.h"

#include "infini_train/include/autocast.h"
#include "infini_train/include/core/device_guard.h"
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/modules/loss.h"
@@ -121,24 +121,25 @@ void Train(const nn::parallel::Rank &rank) {

     if (rank.IsParallel()) {
         device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
+        auto *pg_factory = ProcessGroupFactory::Instance(device.type());
 
         if (ddp_world_size > 1) {
-            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
-                                                                  GetDataParallelGroupRanks(rank.GlobalRank()));
+            ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                             GetDataParallelGroupRanks(rank.GlobalRank()));
             ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
         }
 
         if (tp_world_size > 1) {
-            tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetTensorParallelGroupRanks(rank.GlobalRank()));
+            tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                            GetTensorParallelGroupRanks(rank.GlobalRank()));
             tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }
 
         if (pp_world_size > 1) {
-            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetPipelineParallelGroupRanks(rank.GlobalRank()));
+            pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                            GetPipelineParallelGroupRanks(rank.GlobalRank()));
             pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
 
             nn::parallel::pp_rank = pp_rank;
11 changes: 0 additions & 11 deletions infini_train/include/core/blas_handle.h

This file was deleted.

100 changes: 100 additions & 0 deletions infini_train/include/core/ccl/ccl.h
@@ -0,0 +1,100 @@
#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>

#include "infini_train/include/core/ccl/ccl_common.h"
#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/parallel/reduce_op_type.h"

namespace infini_train::core {

class Stream;

class CclImpl {
public:
CclImpl() {}
virtual ~CclImpl() = default;

virtual Device::DeviceType Type() const = 0;

virtual void GroupStart() const;

virtual void GroupEnd() const;

virtual void GetAsyncError(const CclComm *comm, CclStatus *async_error) const;

virtual void GetUniqueId(CclUniqueId **unique_id) const;

virtual void CommInitAll(CclComm **comms, int ndev, const int *devlist) const;

virtual void CommInitRank(CclComm **comm, int nranks, const CclUniqueId &unique_id, int rank) const;

virtual void CommDestroy(CclComm *comm) const;

virtual void AllReduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const;

virtual void Broadcast(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, int root,
const CclComm *comm, Stream *stream) const;

virtual void Reduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
nn::parallel::function::ReduceOpType reduce_op, int root, const CclComm *comm,
Stream *stream) const;

virtual void AllGather(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm,
Stream *stream) const;

virtual void ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_count, DataType dtype,
nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm,
Stream *stream) const;

virtual void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
Stream *stream) const;

virtual void Recv(void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const;
};

class CclGroupGuard {
public:
explicit CclGroupGuard(Device::DeviceType type);
~CclGroupGuard();

CclGroupGuard(const CclGroupGuard &) = delete;
CclGroupGuard &operator=(const CclGroupGuard &) = delete;
CclGroupGuard(CclGroupGuard &&) = delete;
CclGroupGuard &operator=(CclGroupGuard &&) = delete;

private:
CclImpl *impl_ = nullptr;
};

class CclImplRegistry {
public:
static CclImplRegistry &Instance();

void Register(Device::DeviceType type, std::unique_ptr<CclImpl> impl);

CclImpl *Get(Device::DeviceType type) const;

private:
CclImplRegistry() = default;
CclImplRegistry(const CclImplRegistry &) = delete;
CclImplRegistry &operator=(const CclImplRegistry &) = delete;

std::unordered_map<Device::DeviceType, std::unique_ptr<CclImpl>> impls_;
};

CclImpl *GetCclImpl(Device::DeviceType type);

} // namespace infini_train::core

// NOTE: pasting __COUNTER__ directly with ## would inhibit its expansion and
// always produce the same identifier; the two-level concat expands it first.
#define INFINI_TRAIN_CCL_CONCAT_IMPL(a, b) a##b
#define INFINI_TRAIN_CCL_CONCAT(a, b) INFINI_TRAIN_CCL_CONCAT_IMPL(a, b)

#define INFINI_TRAIN_REGISTER_CCL_IMPL(device_type, class_impl)                                                       \
    static const bool INFINI_TRAIN_CCL_CONCAT(__infini_train_ccl_registered, __COUNTER__) = []() {                    \
        infini_train::core::CclImplRegistry::Instance().Register(device_type, std::make_unique<class_impl>());        \
        return true;                                                                                                  \
    }();
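The registry plus registration macro give each backend a self-contained entry point. A hedged sketch of how a CUDA backend might plug in; NcclCclImpl below is illustrative, with the real implementation living under infini_train/src/core/ccl/cuda/ and excluded from the build when USE_NCCL is off (see the CMakeLists.txt change above):

#include "infini_train/include/core/ccl/ccl.h"

namespace {

// Illustrative backend: only Type() is shown; a real backend would also
// override AllReduce/Broadcast/Reduce/... in terms of the NCCL API.
class NcclCclImpl : public infini_train::core::CclImpl {
public:
    infini_train::Device::DeviceType Type() const override {
        return infini_train::Device::DeviceType::kCUDA;
    }
};

INFINI_TRAIN_REGISTER_CCL_IMPL(infini_train::Device::DeviceType::kCUDA, NcclCclImpl)

} // namespace

// Call sites dispatch through the registry; grouped point-to-point calls can
// be batched with the RAII guard:
//   auto *ccl = infini_train::core::GetCclImpl(device.type());
//   {
//       infini_train::core::CclGroupGuard guard(device.type());
//       ccl->Send(send_buf, count, dtype, peer, comm, stream);
//       ccl->Recv(recv_buf, count, dtype, peer, comm, stream);
//   }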
58 changes: 58 additions & 0 deletions infini_train/include/core/ccl/ccl_common.h
@@ -0,0 +1,58 @@
#pragma once

#include <cstddef>
#include <cstdint>

#include "glog/logging.h"

namespace infini_train::core {

#define INFINI_TRAIN_CCL_STATUS_LIST(X) \
X(kSuccess, 0) \
X(kInProgress, 1) \
X(kTimeout, 2) \
X(kError, -1) \
X(kInvalidArgument, -2) \
X(kUnavailable, -3) \
X(kNotSupported, -4) \
X(kInternal, -5) \
X(kUnknown, -127)

enum class CclStatus : int32_t {
#define INFINI_TRAIN_CCL_STATUS_ENUM_ITEM(name, value) name = value,
INFINI_TRAIN_CCL_STATUS_LIST(INFINI_TRAIN_CCL_STATUS_ENUM_ITEM)
#undef INFINI_TRAIN_CCL_STATUS_ENUM_ITEM
};

inline const char *CclStatusToString(CclStatus status) {
switch (status) {
#define INFINI_TRAIN_CCL_STATUS_CASE(name, value) \
case CclStatus::name: \
return #name;
INFINI_TRAIN_CCL_STATUS_LIST(INFINI_TRAIN_CCL_STATUS_CASE)
#undef INFINI_TRAIN_CCL_STATUS_CASE
default:
LOG(FATAL) << "Unsupported RuntimeStatus type: " << static_cast<int>(status);
return "";
}
}

#undef INFINI_TRAIN_CCL_STATUS_LIST

class CclComm {
public:
CclComm() = default;
virtual ~CclComm() = default;
};

class CclUniqueId {
public:
CclUniqueId() = default;
virtual ~CclUniqueId() = default;

virtual size_t Size() const = 0;
virtual const void *Data() const = 0;
virtual void Load(const void *src, size_t size) = 0;
};

} // namespace infini_train::core
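Because the enum and CclStatusToString are generated from the same X-macro list, the two cannot drift apart. A hedged usage sketch, polling a communicator for asynchronous failures; device and comm are assumed to be in scope:

using namespace infini_train::core;

CclStatus status = CclStatus::kSuccess;
GetCclImpl(device.type())->GetAsyncError(comm, &status);
if (status != CclStatus::kSuccess) {
    LOG(ERROR) << "CCL async error: " << CclStatusToString(status);
}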
15 changes: 15 additions & 0 deletions infini_train/include/core/ccl/ccl_utils.h
@@ -0,0 +1,15 @@
#pragma once

#include <string>

#include "infini_train/include/core/ccl/ccl_common.h"

namespace infini_train::core {

void WriteUniqueIdFile(const CclUniqueId &unique_id, const std::string &pg_name);

void ReadUniqueIdFile(CclUniqueId *unique_id, const std::string &pg_name);

void CleanupUniqueIdFile(const std::string &pg_name);

} // namespace infini_train::core
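These helpers imply a file-based rendezvous for the communicator id. A hedged sketch of the flow, assuming ccl, pg_name, group_rank, and nranks are in scope, and assuming GetUniqueId can also be used on non-root ranks just to allocate backend-specific storage for Load():

CclUniqueId *unique_id = nullptr;
ccl->GetUniqueId(&unique_id); // root: fresh id; others: storage to Load() into
if (group_rank == 0) {
    WriteUniqueIdFile(*unique_id, pg_name); // publish the root's id
} else {
    ReadUniqueIdFile(unique_id, pg_name); // presumably fills it via Load()
}
CclComm *comm = nullptr;
ccl->CommInitRank(&comm, nranks, *unique_id, group_rank);
if (group_rank == 0) {
    CleanupUniqueIdFile(pg_name); // assumption: root removes the file after init
}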
infini_train/include/core/runtime/device_guard.h
@@ -4,6 +4,7 @@
 #include <memory>
 #include <unordered_map>
 
+#include "infini_train/include/core/runtime/runtime_common.h"
 #include "infini_train/include/device.h"
 
 namespace infini_train::core {
@@ -55,8 +56,6 @@ inline const char *MemcpyKindToString(MemcpyKind k) {
 // DeviceGuard (the public RAII wrapper) forwards calls to the DeviceGuardImpl
 // instance registered for the device type.
 //
-// TODO(dcj): add event management
-//
 class DeviceGuardImpl {
 public:
     DeviceGuardImpl() {}
@@ -81,6 +80,34 @@ class DeviceGuardImpl {

     virtual Stream *GetStream(Device) const;
 
+    virtual Stream *CreateStream(Device) const;
+
+    virtual Stream *CreateStreamWithPriority(Device, int priority) const;
+
+    virtual void DestroyStream(Stream *) const;
+
+    virtual void GetStreamPriorityRange(int *low, int *high) const;
+
+    // ----------------------------------------------------------------------
+    // Event management
+    // ----------------------------------------------------------------------
+
+    virtual void EventCreate(Event **event) const;
+
+    virtual void EventCreateWithFlags(Event **event, EventFlag flags) const;
+
+    virtual void EventDestroy(Event *event) const;
+
+    virtual void EventRecord(Event *event, Stream *stream) const;
+
+    virtual void StreamWaitEvent(Stream *stream, Event *event, uint32_t flags) const;
+
+    virtual RuntimeStatus EventSynchronize(Event *event) const;
+
+    virtual RuntimeStatus EventQuery(Event *event) const;
+
+    virtual float EventElapsedTime(Event *start_event, Event *stop_event) const;
+
     // ----------------------------------------------------------------------
     // Synchronization
     // ----------------------------------------------------------------------
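With the event TODO resolved, stream sections can be timed through the impl interface. A hedged sketch; impl is the DeviceGuardImpl registered for the device, stream an existing Stream, and RuntimeStatus::kSuccess is assumed to be the success value defined in runtime_common.h:

Event *start = nullptr;
Event *stop = nullptr;
impl->EventCreate(&start);
impl->EventCreate(&stop);

impl->EventRecord(start, stream);
// ... enqueue work on `stream` ...
impl->EventRecord(stop, stream);

if (impl->EventSynchronize(stop) == RuntimeStatus::kSuccess) {
    float ms = impl->EventElapsedTime(start, stop); // CUDA-style milliseconds
    LOG(INFO) << "timed section: " << ms << " ms";
}

impl->EventDestroy(start);
impl->EventDestroy(stop);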