tensorflow源码阅读-op注册
一、OP注册
//tensorflow/core/platform/macros.h
// TF_DISALLOW_COPY_AND_ASSIGN(TypeName) deletes the copy-assignment operator
// and the copy constructor of TypeName, making the class non-copyable.
// It is usually placed in the class's private: section, and the invocation
// is terminated by the caller's own semicolon, e.g.
//   TF_DISALLOW_COPY_AND_ASSIGN(Foo);
#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  void operator=(const TypeName&) = delete;   \
  TypeName(const TypeName&) = delete
作用：禁止类的拷贝构造函数和拷贝赋值运算符（二者都被声明为 = delete），使该类不可复制。
//tensorflow/core/platform/macros.h
// TF_ATTRIBUTE_UNUSED marks a function or variable as possibly unused so the
// compiler does not emit an unused-entity warning for it (REGISTER_OP relies
// on this for its static registration variables, which are never referenced
// by name).
// __attribute__ is a GNU extension; compilers without it (e.g. MSVC) would
// reject the code, so on those the macro expands to nothing.
#if defined(__GNUC__) || defined(__clang__)
#define TF_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TF_ATTRIBUTE_UNUSED
#endif
作用：标记函数或变量可能不会被使用，从而避免编译器产生“未使用”警告。REGISTER_OP 定义的静态注册变量从不会被按名字引用，因此需要加上这个属性。
// tensorflow/core/framework/selective_registration.h
// Compile-time switches deciding which ops, gradients and op kernels get
// registered.
//
// With SELECTIVE_REGISTRATION defined (e.g. -DSELECTIVE_REGISTRATION), the
// user-provided header ops_to_register.h must define all three
// SHOULD_REGISTER_* macros; if any is missing, compilation fails with the
// static_assert below. Without SELECTIVE_REGISTRATION every op, gradient and
// kernel is registered.
#ifdef SELECTIVE_REGISTRATION
#include "ops_to_register.h"
#if (!defined(SHOULD_REGISTER_OP) || !defined(SHOULD_REGISTER_OP_GRADIENT) || \
     !defined(SHOULD_REGISTER_OP_KERNEL))
static_assert(false, "ops_to_register.h must define SHOULD_REGISTER macros");
#endif
#else
#define SHOULD_REGISTER_OP(op) true
#define SHOULD_REGISTER_OP_GRADIENT true
#define SHOULD_REGISTER_OP_KERNEL(clz) true
#endif
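默认情况下（未定义 SELECTIVE_REGISTRATION，即不启用选择性注册），SHOULD_REGISTER_OP(op) 等三个宏都直接展开为 true，所有 op 都会被注册。启用选择性注册时，这三个宏必须由 ops_to_register.h 提供，并且只有 SHOULD_REGISTER_OP(op) 为真的 op 才会被注册。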
// tensorflow/core/framework/op.h
// REGISTER_OP("Name") defines a static OpDefBuilderReceiver variable. Its
// initializer is a temporary OpDefBuilderWrapper<SHOULD_REGISTER_OP("Name")>,
// and the chained .Input()/.Output()/.Attr()/... calls written after the macro
// configure that wrapper before it reaches the receiver.
//
// __COUNTER__ yields a fresh integer on each expansion, so every registration
// gets a unique variable name (register_op0, register_op1, ...). The extra
// REGISTER_OP_UNIQ_HELPER level makes __COUNTER__ expand to its number before
// it reaches the ## token paste, which would otherwise paste the literal
// token __COUNTER__. TF_ATTRIBUTE_UNUSED silences the warning for a variable
// that is never referenced by name.
#define REGISTER_OP_UNIQ(counter, op_name)                                     \
  static ::tensorflow::register_op::OpDefBuilderReceiver register_op##counter \
      TF_ATTRIBUTE_UNUSED =                                                   \
          ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP(  \
              op_name)>(op_name)
#define REGISTER_OP_UNIQ_HELPER(counter, op_name) \
  REGISTER_OP_UNIQ(counter, op_name)
#define REGISTER_OP(op_name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, op_name)
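以上是 op 注册的宏定义。__COUNTER__ 每次展开都会得到一个新的整数，用来生成唯一的静态变量名（register_op0、register_op1……）。之所以多包一层 REGISTER_OP_UNIQ_HELPER，是为了让 __COUNTER__ 先被展开成具体数字，再参与 ## 拼接。例如，矩阵乘法 op 的注册代码如下：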
//tensorflow/core/ops/math_ops.cc
// Registers the "MatMul" op. REGISTER_OP("MatMul") starts a chain of builder
// calls; each chained call below adds one piece of the op definition, and the
// order of the calls is the order in which inputs/outputs/attrs are declared.
REGISTER_OP("MatMul")
// Inputs "a" and "b" and output "product" all have element type T.
.Input("a: T")
.Input("b: T")
.Output("product: T")
// Bool attributes defaulting to false; presumably they request that a (resp. b)
// be transposed before the multiplication.
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
// T is a type attribute restricted to the listed element types.
.Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
// Shape inference for this op is delegated to shape_inference::MatMulShape.
.SetShapeFn(shape_inference::MatMulShape);
REGISTER_OP("MatMul") 展开后，首先构造模板类 OpDefBuilderWrapper 的一个临时对象。由于默认情况下 SHOULD_REGISTER_OP("MatMul") 为 true，实际使用的是特化版本 OpDefBuilderWrapper<true>("MatMul")：
// Specialization selected when SHOULD_REGISTER_OP(name) is true. Every setter
// forwards its argument to the wrapped OpDefBuilder and then returns *this, so
// the calls can be chained directly after REGISTER_OP(...).
template <>
class OpDefBuilderWrapper<true> {
 private:
  // The builder that accumulates the op definition.
  mutable ::tensorflow::OpDefBuilder builder_;

 public:
  // Starts the definition of the op called `name` (e.g. "MatMul").
  OpDefBuilderWrapper(const char name[]) : builder_(name) {}

  // --- Signature: attributes, inputs and outputs, each given as a spec
  //     string such as "a: T" or "T: {float, double}".
  OpDefBuilderWrapper& Attr(StringPiece spec) {
    builder_.Attr(spec);
    return *this;
  }
  OpDefBuilderWrapper& Input(StringPiece spec) {
    builder_.Input(spec);
    return *this;
  }
  OpDefBuilderWrapper& Output(StringPiece spec) {
    builder_.Output(spec);
    return *this;
  }

  // --- Argument-less flag setters, forwarded as-is.
  OpDefBuilderWrapper& SetIsCommutative() {
    builder_.SetIsCommutative();
    return *this;
  }
  OpDefBuilderWrapper& SetIsAggregate() {
    builder_.SetIsAggregate();
    return *this;
  }
  OpDefBuilderWrapper& SetIsStateful() {
    builder_.SetIsStateful();
    return *this;
  }
  OpDefBuilderWrapper& SetAllowsUninitializedInput() {
    builder_.SetAllowsUninitializedInput();
    return *this;
  }

  // --- Deprecation and documentation.
  OpDefBuilderWrapper& Deprecated(int version, StringPiece explanation) {
    builder_.Deprecated(version, explanation);
    return *this;
  }
  OpDefBuilderWrapper& Doc(StringPiece text) {
    builder_.Doc(text);
    return *this;
  }

  // --- Shape inference: `fn` is handed straight to OpDefBuilder::SetShapeFn.
  OpDefBuilderWrapper& SetShapeFn(
      Status (*fn)(shape_inference::InferenceContext*)) {
    builder_.SetShapeFn(fn);
    return *this;
  }

  // Read-only access to the accumulated builder (the registration lambda
  // calls Finalize() on it).
  const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
};
OpDefBuilderWrapper类包含一个OpDefBuilder 对象
mutable ::tensorflow::OpDefBuilder builder_;
当执行 REGISTER_OP("MatMul") 时，会调用 OpDefBuilderWrapper 的构造函数
OpDefBuilderWrapper(const char name[]) : builder_(name) {}。它在初始化列表中调用 OpDefBuilder 的构造函数 explicit OpDefBuilder(StringPiece op_name); 来初始化成员 builder_，从而得到一个 OpDefBuilderWrapper 临时对象（构造函数本身并不“返回”对象）。
也就是说，REGISTER_OP("MatMul") 这个表达式本身就是一个 OpDefBuilderWrapper 对象，内部包含一个 OpDefBuilder 对象。由于 Input()、Attr() 等每个设置函数都返回 *this（即对象自身的引用），就可以连续设置 op 的属性：
REGISTER_OP("MatMul")
.Input("a: T")
.Input("b: T") ...
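其实 OpDefBuilder 自身的设置函数也支持链式调用。之所以还要包装成 OpDefBuilderWrapper，关键在于它的 bool 模板参数 SHOULD_REGISTER_OP(name)。该值为 true 时，使用上面的特化版本，把各项设置转发给 OpDefBuilder。启用选择性注册且该值为 false 时，会选用另一个什么都不做的特化版本 OpDefBuilderWrapper<false>（本文未列出），这样该 op 就不会被注册。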
OpDefBuilder 把各项 op 属性以字符串形式保存在成员变量（std::vector<string>）中：
//tensorflow/core/framework/op_def_builder.h
// Excerpt: the private state of class OpDefBuilder (the class header and
// public interface are omitted here).
private:
// Mutable access to the OpDef stored inside op_reg_data_.
OpDef* op_def() { return &op_reg_data_.op_def; }
// The registration data (OpDef, shape fn, ...) being assembled.
OpRegistrationData op_reg_data_;
// Presumably the raw spec strings passed to Attr()/Input()/Output(), kept
// until Finalize() parses them -- TODO confirm against op_def_builder.cc.
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
// Presumably the text passed to Doc() -- verify.
string doc_;
// Presumably problems collected while building, reported later -- verify.
std::vector<string> errors_;
};
所有属性设置完之后，整个链式表达式的结果仍然是这个 OpDefBuilderWrapper<true> 对象。它被用来初始化 REGISTER_OP 定义的静态变量，也就是作为参数传给 OpDefBuilderReceiver 的构造函数，具体如下。
// Other registration ---------------------------------------------------------
// tensorflow/core/framework/op.cc
namespace register_op {

// Runs when the static variable defined by REGISTER_OP(...) is initialized:
// hands the global OpRegistry a factory that finalizes this op's definition
// into an OpRegistrationData when the registry invokes it.
OpDefBuilderReceiver::OpDefBuilderReceiver(
    const OpDefBuilderWrapper<true>& wrapper) {
  // `wrapper` is captured by value: it refers to the temporary created by
  // REGISTER_OP, while OpRegistry::Register may only queue the factory
  // (deferred registration) and call it after that temporary is gone.
  auto finalize_op = [wrapper](OpRegistrationData* reg_data) -> Status {
    return wrapper.builder().Finalize(reg_data);
  };
  OpRegistry* const registry = OpRegistry::Global();
  registry->Register(finalize_op);
}

}  // namespace register_op
OpDefBuilderReceiver 的构造函数收到 wrapper 后，调用 OpRegistry::Global()->Register() 进行注册。传给 Register() 的参数是一个 lambda 表达式，它按值捕获了 wrapper 对象。之所以按值捕获，是因为注册可能被延迟执行，而 wrapper 只是一个临时对象，按引用捕获会悬空。lambda 中的 wrapper.builder() 调用的是 OpDefBuilderWrapper 里的：
const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
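它返回包含注册信息的 builder_ 对象的 const 引用。随后调用 Finalize(op_reg_data)，把 builder_ 中记录的信息解析出来，填入 OpRegistrationData 结构体，并返回一个 Status 表示是否成功。注意，Finalize 要等到 registry 真正调用这个 lambda 时才会执行；如果注册被延迟（见下文 deferred_），它会晚于 Register() 的调用。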
OpRegistry 继承自 OpRegistryInterface，其中有一个静态函数 Global()（单例模式）：
static OpRegistry* Global();
具体实现如下:
// tensorflow/core/framework/op.cc
// static
// Returns the process-wide OpRegistry singleton. It is created on the first
// call (initialization of a function-local static is thread-safe since C++11),
// and every later call returns the same pointer. The object is never deleted,
// presumably on purpose to avoid static-destruction-order problems.
OpRegistry* OpRegistry::Global() {
  static OpRegistry* const the_registry = new OpRegistry;
  return the_registry;
}
Global() 第一次被调用时创建 OpRegistry 对象（该对象从不释放），之后每次调用都返回同一个指针。OpRegistry::Global()->Register() 调用的是它的 Register() 方法：
void Register(const OpRegistrationDataFactory& op_data_factory);
这里的参数类型 OpRegistrationDataFactory 只是 std::function 的一个类型别名：
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
std::function 可以把普通函数、lambda 表达式和函数对象统一起来。它们本身的类型各不相同，但都可以转换为同一种 std::function 类型的对象，因此能放进同一个容器中（这里是 deferred_ 这个 vector）。
每个 lambda 表达式都有自己独一无二的闭包类型，并不是 std::function。因此，把 lambda 传给 Register() 时，它会被隐式转换成 std::function（即 OpRegistrationDataFactory）。
OpRegistry::Register的具体实现如下
// Registers the op described by op_data_factory, holding mu_ for the whole
// call. Until the registry is initialized, the factory is only queued in
// deferred_ for later processing. Afterwards it is registered immediately via
// RegisterAlreadyLocked(), whose Status is checked by TF_QCHECK_OK (a failed
// check is fatal).
void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (!initialized_) {
    // Registry not initialized yet: defer this registration.
    deferred_.push_back(op_data_factory);
    return;
  }
  TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
}
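Register() 先加锁。若 registry 已初始化（initialized_ 为 true），就立即调用 RegisterAlreadyLocked() 进行注册。如果已经有同名的 OpDef 被注册过，它返回非 OK 的状态，TF_QCHECK_OK 检查失败会使程序终止；否则，它把 op 的 name 和 OpRegistrationData 组成 pair 放进 hashmap，完成注册。若 registry 尚未初始化，则只把工厂函数放入 deferred_。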
deferred_ 用于延迟注册：registry 尚未初始化时，注册请求先暂存在这里，留待之后处理。其声明为：
mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
GUARDED_BY(mu_)（定义在 tensorflow/core/platform/default/thread_annotations.h）是一个线程安全注解，表示 deferred_ 是共享变量，只能在持有互斥锁 mu_ 时访问。它本身并不加锁，而是供 Clang 的 -Wthread-safety 静态分析检查是否遵守了这一约定。
OpRegistrationData的解释
//tensorflow/core/framework/op_def_builder.h
typedef std::function<Status(shape_inference::InferenceContext* c)>
OpShapeInferenceFn;
struct OpRegistrationData {
public:
OpRegistrationData() {}
OpRegistrationData(const OpDef& def) : op_def(def) {}
OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn,
bool is_function = false)
: op_def(def), shape_inference_fn(fn), is_function_op(is_function) {}
OpDef op_def;
OpShapeInferenceFn shape_inference_fn;
bool is_function_op = false;
};
OpRegistrationData 是一个结构体，主要包含 OpDef 和 OpShapeInferenceFn 两个成员，另外还有一个标志位 is_function_op。
其中，OpDef 是由 protobuf 根据 tensorflow/core/framework/op_def.proto 生成的类，记录了 op 的定义（名称、输入、输出、属性等）；OpShapeInferenceFn 则描述 op 输出 Shape 的推断规则。