What is one-hot?

SEO:
libtorch 如何 OneHot ?
torch OneHot 源代码 ?

https://www.tensorflow.org/api_docs/python/tf/one_hot

libtorch 最新的 1.3 版本中已经内置了该函数(torch::one_hot),其源代码如下:

#include <torch/torch.h>
#include <c10/util/StringUtil.h>
/// Converts a tensor of class indices into its one-hot encoding.
///
/// The result has the shape of `self` plus one extra trailing dimension of
/// size `num_classes`. Along that dimension the entry whose position equals
/// the class index is 1 and all other entries are 0. The result is created
/// with the same tensor options as `self`.
///
/// @param self         Index tensor; its dtype must be torch::kLong and, when
///                     it is non-empty, every value must be non-negative.
/// @param num_classes  Size of the new trailing dimension. Pass -1 to infer it
///                     as `max(self) + 1`; inference is not possible for an
///                     empty tensor, which requires a positive value.
/// @return The one-hot encoded tensor.
/// @throws c10::Error  if `self` is not a kLong tensor, if `self` is empty and
///                     `num_classes` is not positive, if any class value is
///                     negative, or if an explicit `num_classes` is not
///                     greater than the largest class value.
torch::Tensor one_hot(const torch::Tensor &self, int64_t num_classes) {
	TORCH_CHECK(self.dtype() == torch::kLong, "one_hot is only applicable to index tensor.");
	auto shape = self.sizes().vec();

	// An empty tensor can still be converted to a one-hot representation,
	// but the number of classes cannot be inferred from it.
	if (self.numel() == 0) {
		TORCH_CHECK(num_classes > 0, "Can not infer total number of classes from empty tensor.");
		shape.push_back(num_classes);
		// The result holds no elements either, so nothing needs initializing.
		return at::empty(shape, self.options());
	}

	// Non-empty tensor: validate the value range before scattering.
	TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
	const int64_t max_class = self.max().item().toLong();
	if (num_classes == -1) {
		// -1 is the "infer from data" sentinel.
		num_classes = max_class + 1;
	}
	else {
		TORCH_CHECK(num_classes > max_class, "Class values must be smaller than num_classes.");
	}

	shape.push_back(num_classes);
	torch::Tensor ret = at::zeros(shape, self.options());
	// Write a 1 at position self[...] along the new trailing dimension.
	ret.scatter_(-1, self.unsqueeze(-1), 1);
	return ret;
}

使用示例

	torch::TensorOptions options(torch::kLong);
	auto tensor = torch::tensor({ 0,1,2 }, options);
	std::cout << tensor << std::endl;

	try
	{
		auto one_hot = torch::one_hot(tensor,4);
		std::cout << one_hot << std::endl;
	}
	catch (const c10::Error& watch)
	{
		std::cout << watch.msg() << std::endl;
	}
原文地址:https://www.cnblogs.com/cheungxiongwei/p/11872446.html