tensorflow源码阅读-op注册

一、OP注册

//tensorflow/core/platform/maxros.h
// A macro to disallow the copy constructor and operator= functions
// This is usually placed in the private: declarations for a class.
#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;         \
  void operator=(const TypeName&) = delete

不允许复制构造函数和操作

//tensorflow/core/platform/maxros.h
#define TF_ATTRIBUTE_UNUSED __attribute__((unused))

表示该函数或变量可能不使用,这个属性可以避免编译器产生警告信息

// tensorflow/core/framework/selective_registration.h
#if (!defined(SHOULD_REGISTER_OP) || !defined(SHOULD_REGISTER_OP_GRADIENT) || \
     !defined(SHOULD_REGISTER_OP_KERNEL))
static_assert(false, "ops_to_register.h must define SHOULD_REGISTER macros");
#endif
#else
#define SHOULD_REGISTER_OP(op) true
#define SHOULD_REGISTER_OP_GRADIENT true
#define SHOULD_REGISTER_OP_KERNEL(clz) true
#endif

注册op需要定义SHOULD_REGISTER_OP(op) 为真

// tensorflow/core/framework/op.h
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP_UNIQ(ctr, name)                   \         
static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
      TF_ATTRIBUTE_UNUSED =                                                 \     
      ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
              name)>(name)

op注册宏定义,用户注册矩阵相乘的op如下:

//tensorflow/core/ops/math_ops.cc
REGISTER_OP("MatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
    .SetShapeFn(shape_inference::MatMulShape);

REGISTER_OP宏定义首先调用OpDefBuilderWrapper这个模板类OpDefBuilderWrapper<true>("MatMul")

template <>
class OpDefBuilderWrapper<true> {
 public:
  OpDefBuilderWrapper(const char name[]) : builder_(name) {}
  OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
    builder_.Attr(spec);
    return *this;
  }
  OpDefBuilderWrapper<true>& Input(StringPiece spec) {
    builder_.Input(spec);
    return *this;
  }
  OpDefBuilderWrapper<true>& Output(StringPiece spec) {
    builder_.Output(spec);
    return *this;
  }
  OpDefBuilderWrapper<true>& SetIsCommutative() {
    builder_.SetIsCommutative();
    return *this;
  }
  OpDefBuilderWrapper<true>& SetIsAggregate() {
    builder_.SetIsAggregate();
    return *this;
  }
  OpDefBuilderWrapper<true>& SetIsStateful() {
    builder_.SetIsStateful();
    return *this;
  }
  OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
    builder_.SetAllowsUninitializedInput();
    return *this;
  }
  OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
    builder_.Deprecated(version, explanation);
    return *this;
  }
  OpDefBuilderWrapper<true>& Doc(StringPiece text) {
    builder_.Doc(text);
    return *this;
  }
  OpDefBuilderWrapper<true>& SetShapeFn(
      Status (*fn)(shape_inference::InferenceContext*)) {
    builder_.SetShapeFn(fn);
    return *this;
  }
  const ::tensorflow::OpDefBuilder& builder() const { return builder_; }

 private:
  mutable ::tensorflow::OpDefBuilder builder_;
};

OpDefBuilderWrapper类包含一个OpDefBuilder 对象
mutable ::tensorflow::OpDefBuilder builder_;
当执行REGISTER_OP(“MatMul”)时,OpDefBuilderWrapper的构造函数
OpDefBuilderWrapper(const char name[]) : builder_(name) {}将进行初始化操作,调用OpDefBuilder的构造函数 explicit OpDefBuilder(StringPiece op_name);并返回对象本身。
也就是说REGISTER_OP(“MatMul”)也是一个OpDefBuilderWrapper对象,里面包含一个OpDefBuilder 对象。这样就可以连续设置op的属性:
REGISTER_OP(“MatMul”)
.Input(“a: T”)
.Input(“b: T”)…
把OpDefBuilder包装成OpDefBuilderWrapper是为了链式设置op属性

OpDefBuilder 将对应的op属性保存到模板库里面(std::vector):

//tensorflow/core/framework/op_def_builder.h
private:
  OpDef* op_def() { return &op_reg_data_.op_def; }

  OpRegistrationData op_reg_data_;
  std::vector<string> attrs_;
  std::vector<string> inputs_;
  std::vector<string> outputs_;
  string doc_;
  std::vector<string> errors_;
};

当所有属性都设置完之后,其本身是一个OpDefBuilderWrapper对象,传到OpDefBuilderReceiver里,具体如下。

// Other registration ---------------------------------------------------------
// tensorflow/core/framework/op.cc
namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilderWrapper<true>& wrapper) 
{
  OpRegistry::Global()->Register(
  		[wrapper](OpRegistrationData* op_reg_data) -> Status {
        	return wrapper.builder().Finalize(op_reg_data);
      });
}
}  // namespace register_op

OpDefBuilderReceiver接收到wrapper交给 OpRegistry::Global()->Register进行注册,Register()里面包含一个lanmda表达式,捕获到wrapper这个对象,wrapper.builder()是调用OpDefBuilderWrapper里面的:
const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
把包含注册信息的builder_对象返回,调用Finalize(op_reg_data),进行解析,把builder_里的信息解析出来放到OpRegistrationData这个结构体里面,返回Statu状态

OpRegistry继承OpRegistryInterface,里面有一个静态函数(单例模式)
static OpRegistry* Global();
具体实现如下:

// tensorflow/core/framework/op.cc
// static
OpRegistry* OpRegistry::Global() {
  static OpRegistry* global_op_registry = new OpRegistry;
  return global_op_registry;
}

它创建一个OpRegistry对象并返回,OpRegistry::Global()->Register()调用Register()方法
void Register(const OpRegistrationDataFactory& op_data_factory);
这里的OpRegistrationDataFactory进行了包装
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
function可以将普通函数,lambda表达式和函数对象类统一起来。它们并不是相同的类型,然而通过function模板类,可以转化为相同类型的对象(function对象),从而放入一个map里。
lambda函数的类型是std:function,因此要这样转化。

OpRegistry::Register的具体实现如下

void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
  } else {
    deferred_.push_back(op_data_factory);
  }
}

已经有一个注册了该名称的OpDef则返回non-okay状态,否者将op的name和OpRegistrationData组成pair放进hashmap完成注册。

deferred_ 是为了延时注册
mutable std::vector deferred_ GUARDED_BY(mu_);
GUARDED_BY(mu_)(tensorflow/core/platform/default/thread_annotations.h)指的是共享变量加锁。

OpRegistrationData的解释

//tensorflow/core/framework/op_def_builder.h
typedef std::function<Status(shape_inference::InferenceContext* c)>
    OpShapeInferenceFn;

struct OpRegistrationData {
 public:
  OpRegistrationData() {}
  OpRegistrationData(const OpDef& def) : op_def(def) {}
  OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn,
                     bool is_function = false)
      : op_def(def), shape_inference_fn(fn), is_function_op(is_function) {}

  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
  bool is_function_op = false;
};

OpRegistrationData是一个结构体,里面主要包含OpDef 和OpShapeInferenceFn 两个对象
其中OpDef 是tensorflow/core/framework/op_def.proto通过protobuf产生的类,里面包含着op的属性。 OpShapeInferenceFn 描述 OP 的 Shape 的推演规则


版权声明:本文为weixin_43800762原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
THE END
< <上一篇
下一篇>>