Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ set(INFINIOPS_OPS "" CACHE STRING
set(INFINIOPS_SMOKE_BUILD OFF CACHE BOOL
"Build only the smoke-test operator subset")
set(_infiniops_smoke_ops
add mul cast cat gemm matmul linear rms_norm swiglu causal_softmax abs clamp exp)
add mul cast cat gemm matmul linear rms_norm embedding swiglu causal_softmax abs clamp exp)
set(_infiniops_smoke_torch_ops abs clamp exp)

if(INFINIOPS_SMOKE_BUILD)
Expand Down
96 changes: 96 additions & 0 deletions src/base/embedding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#ifndef INFINI_OPS_BASE_EMBEDDING_H_
#define INFINI_OPS_BASE_EMBEDDING_H_

#include <cstddef>

#include "common/assert_utils.h"
#include "data_type.h"
#include "operator.h"
#include "tensor.h"

namespace infini::ops {

// Aligned with InfiniCore and `torch.nn.functional.embedding`.
class Embedding : public Operator<Embedding> {
public:
Embedding(const Tensor input, const Tensor weight, Tensor out)
: input_shape_{input.shape()},
weight_shape_{weight.shape()},
out_shape_{out.shape()},
input_strides_{input.strides()},
weight_strides_{weight.strides()},
out_strides_{out.strides()},
input_dtype_{input.dtype()},
weight_dtype_{weight.dtype()},
out_dtype_{out.dtype()},
num_indices_{NumIndices(input_shape_)},
vocab_size_{weight.size(0)},
embedding_dim_{weight.size(1)} {
INFINI_OPS_ASSERT(weight.ndim() == 2, "`Embedding` requires 2D `weight`");
INFINI_OPS_ASSERT(out.ndim() == input.ndim() + 1,
"`Embedding` output rank must be input rank + 1");

for (Tensor::Size i = 0; i < input.ndim(); ++i) {
INFINI_OPS_ASSERT(
out.size(i) == input.size(i),
"`Embedding` output shape must match `input` shape on non-last "
"dims");
}

INFINI_OPS_ASSERT(
out.size(-1) == embedding_dim_,
"`Embedding` output last dim must equal `weight` embedding dim");
INFINI_OPS_ASSERT(
input_dtype_ == DataType::kInt32 || input_dtype_ == DataType::kInt64,
"`Embedding` supports int32 and int64 indices only");
INFINI_OPS_ASSERT(weight_dtype_ == DataType::kFloat32 ||
weight_dtype_ == DataType::kFloat16 ||
weight_dtype_ == DataType::kBFloat16,
"`Embedding` supports float32, float16, and bfloat16 "
"weights only");
INFINI_OPS_ASSERT(out_dtype_ == weight_dtype_,
"`Embedding` output dtype must match `weight` dtype");
}

virtual void operator()(const Tensor input, const Tensor weight,
Tensor out) const = 0;

protected:
static Tensor::Size NumIndices(const Tensor::Shape& input_shape) {
Tensor::Size num_indices = 1;

for (Tensor::Size dim : input_shape) {
num_indices *= dim;
}

return num_indices;
}

Tensor::Shape input_shape_;

Tensor::Shape weight_shape_;

Tensor::Shape out_shape_;

Tensor::Strides input_strides_;

Tensor::Strides weight_strides_;

Tensor::Strides out_strides_;

DataType input_dtype_;

DataType weight_dtype_;

DataType out_dtype_;

Tensor::Size num_indices_{0};

Tensor::Size vocab_size_{0};

Tensor::Size embedding_dim_{0};
};

} // namespace infini::ops

#endif
28 changes: 28 additions & 0 deletions src/common/assert_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef INFINI_OPS_COMMON_ASSERT_UTILS_H_
#define INFINI_OPS_COMMON_ASSERT_UTILS_H_

#include <cassert>

namespace infini::ops::detail {

// Iluvatar/CoreX clang does not accept `(message) " suffix"` string
// concatenation inside macros, so pass `message` as a normal `const char*`.
inline void InfiniOpsAssertFail(const char* message, const char* file, int line,
const char* func) {
(void)file;
(void)line;
(void)func;
assert(false && message);
}

} // namespace infini::ops::detail

#define INFINI_OPS_ASSERT(condition, message) \
do { \
if (!(condition)) { \
::infini::ops::detail::InfiniOpsAssertFail((message), __FILE__, \
__LINE__, __func__); \
} \
} while (0)

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/iluvatar/ops/embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_ILUVATAR_EMBEDDING_KERNEL_H_
#define INFINI_OPS_ILUVATAR_EMBEDDING_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/embedding/kernel.h"

namespace infini::ops {

template <>
class Operator<Embedding, Device::Type::kIluvatar>
: public CudaEmbedding<Runtime<Device::Type::kIluvatar>> {
public:
using CudaEmbedding<Runtime<Device::Type::kIluvatar>>::CudaEmbedding;
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/metax/ops/embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_METAX_EMBEDDING_KERNEL_H_
#define INFINI_OPS_METAX_EMBEDDING_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/embedding/kernel.h"

namespace infini::ops {

template <>
class Operator<Embedding, Device::Type::kMetax>
: public CudaEmbedding<Runtime<Device::Type::kMetax>> {
public:
using CudaEmbedding<Runtime<Device::Type::kMetax>>::CudaEmbedding;
};

} // namespace infini::ops

#endif
25 changes: 25 additions & 0 deletions src/native/cuda/moore/ops/embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef INFINI_OPS_MOORE_EMBEDDING_KERNEL_H_
#define INFINI_OPS_MOORE_EMBEDDING_KERNEL_H_

#include <utility>

// clang-format off
#include "native/cuda/moore/polyfills.cuh"
// clang-format on

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/embedding/kernel.h"

namespace infini::ops {

template <>
class Operator<Embedding, Device::Type::kMoore>
: public CudaEmbedding<Runtime<Device::Type::kMoore>> {
public:
using CudaEmbedding<Runtime<Device::Type::kMoore>>::CudaEmbedding;
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/nvidia/ops/embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_NVIDIA_EMBEDDING_KERNEL_H_
#define INFINI_OPS_NVIDIA_EMBEDDING_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/embedding/kernel.h"

namespace infini::ops {

template <>
class Operator<Embedding, Device::Type::kNvidia>
: public CudaEmbedding<Runtime<Device::Type::kNvidia>> {
public:
using CudaEmbedding<Runtime<Device::Type::kNvidia>>::CudaEmbedding;
};

} // namespace infini::ops

#endif
Loading
Loading