Code前端首页关于Code前端联系我们

TensorFlow核心组件系列中对Graph底层机制的探索

terry 2年前 (2023-09-23) 阅读数 69 #AI人工智能

Graph(计算图)是TensorFlow的核心组件。在TensorFlow中,Graph扮演着重要的角色,用来表示深度学习模型的计算过程和数据流。

Graph的基本机制包括TensorFlow框架的前端和后端系统,它们共同构建和执行计算图。前端系统负责构建计算图并定义模型的结构和操作,而后端系统提供运行环境并负责对计算图执行操作。通过深入研究Graph的底层机制,我们可以揭示TensorFlow框架的工作原理,为模型设计和优化提供指导。

先介绍一下Graph:

在TensorFlow中,Graph是一个由节点(Nodes)和边(Edges)组成的有向无环图。节点代表计算操作,而边代表数据流。每个节点可以接收多个输入边和输出边,形成数据传输和转换关系。这种节点和边的组织结构可以清晰地表示计算过程和模型依赖关系,方便模型优化和并行计算。

在TensorFlow前端系统中,Operations用于表示图中的节点实例,而Tensors表示图中的边实例,用于连接Op节点。 Op包含节点的计算逻辑和操作类型,而Tensor则承载数据并在节点之间传递。

在C++后端系统中,TensorFlow定义了Edges和Nodes来表示计算图。此外,Edge 和 Tensor 之间还存在同步电话连接。 Edge是Node的连接路径,用于传输数据和建立计算依赖,而Tensor是Edge上承载数据的算子。

Tensor维护TF后端系统中的指针信息和基本数据形式,并通过引用计数来管理数据生命周期。这样的设计使得TensorFlow能够实现延迟计算和复用内存,提高计算效率和资源利用率。

1。Graph在计算图表中的功能特点

首先,Graph的特点之一就是灵活性。 TensorFlow计算图允许用户自由定义模型结构和运算,以满足各种复杂的深度学习需求。通过添加、删除或修改节点和边,我们可以灵活地设计和调整计算图以适应不同的任务需求和模型架构。

其次,Graph的特点之一就是计算图的优化。 TensorFlow提供了很多优化技术,可以通过优化计算图来提高模型的性能和效率。例如,图剪枝技术可以去除无用的节点和边,减少计算和内存消耗。同时,利用并行计算技术,可以并行地对计算图进行操作,加速模型训练和推理。此外,还可以利用量化技术对计算图中的张量进行精确压缩,以降低模型的存储和计算成本。

此外,Graph还提供了模型可视化和调试的支持。 TensorFlow提供了可视化工具,可以图形化地展示计算图,帮助开发者直观地了解模型的结构和数据流。通过分析计算图,可以发现并解决潜在的问题,提高模型的稳定性和可靠性。

2。前端(Python)Future 定义

一个Future 对象将包含一系列Operation 对象,这些Operation 对象代表计算单元的集合。同时,它间接包含了一系列Tensor对象,这些对象代表了数据单元的集合。

我们先看一下Python前端Graph的定义:

class Graph(object):
       def __init__(self):
         self._lock = threading.Lock()
         self._nodes_by_id = dict()    # GUARDED_BY(self._lock)
         self._next_id_counter = 0     # GUARDED_BY(self._lock)
         self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
         self._registered_ops = op_def_registry.get_registered_ops()

        def _add_op(self, op):
                self._check_not_finalized()
                if not isinstance(op, (Tensor, Operation)):
                  raise TypeError("op must be a Tensor or Operation: %s" % op)
                with self._lock:
                  # pylint: disable=protected-access
                  if op._id in self._nodes_by_id:
                    raise ValueError("cannot add an op with id %d as it already "
                                     "exists in the graph" % op._id)
                  if op.name in self._nodes_by_name:
                    raise ValueError("cannot add op with name %s as that name "
                                     "is already used" % op.name)
                  self._nodes_by_id[op._id] = op
                  self._nodes_by_name[op.name] = op
                  self._version = max(self._version, op._id)

         def add_to_collection(name, value):
                      get_default_graph().add_to_collection(name, value)

         # 替换线程默认图
               def as_default(self):
                   return _default_graph_stack.get_controller(self)

               # 栈式管理,push pop
               @tf_contextlib.contextmanager
               def get_controller(self, default):
                    try:
                      context.context_stack.push(default.building_function, default.as_default)
                    finally:
                      context.context_stack.pop()

         def create_op(
                  self,
                  op_type,
                  inputs,
                  dtypes=None,  # pylint: disable=redefined-outer-name
                  input_types=None,
                  name=None,
                  attrs=None,
                  op_def=None,
                  compute_shapes=True,
                  compute_device=True):
                        for idx, a in enumerate(inputs):
                          if not isinstance(a, Tensor):
                            raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
                        return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
                                                        attrs, op_def, compute_device)

从源码中可以看出,为了快速索引图中的节点信息,为图中的每个Operation提供了一个唯一的id当前的Future作用域,以及Future中存储_nodes_by_id的数据字典。同时,为了快速通过节点名称索引节点信息,Graph中还存储了一个数据字典_nodes_by_name。

在图构建阶段,通过OP构造函数创建OP,并最终添加到当前的Graph实例中。当图被冻结时,节点无法添加到图中,因此可以在多个线程之间安全地共享Graph实例。

默认堆栈管理图

另外,从图管理中可以看出,默认图采用堆栈管理方式,通过push和pop操作进行管理。当前的默认图像是堆栈顶部的图像。

举个例子:

print tf.get_default_graph()

with tf.Graph().as_default() as g:
    print tf.get_default_graph()

print tf.get_default_graph()

<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>

从上面可以看到,当你在范围内创建一个新的图表并将其用作默认图表时,但离开范围后,图表将是原始图表默认图表。

图片制作工厂

我们来解释一下下面的图片是如何制作的?

当Client使用OP构造函数创建Operation实例时,最终会调用Graph.create_op方法将Operation实例注册到图实例中。

也就是说,Graph一方面充当Operation工厂,负责制作Operation;另一方面,Graph充当Operations的存储库,负责其他Operations的存储、检索、转换和操作。 TensorFlow核心组件系列之Graph的底层机制探索

这个过程通常称为计算图构建。计算图构建过程中,不会触发运行时的OP操作。它只是简单地描述了计算节点之间的依赖关系,并构建了一个DAG图来为整个计算过程创建一个总体规划。

此外,TF还提供了Graph Key类,可以更方便地管理和检索节点信息:

class GraphKeys(object):
    GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  SAVERS = "savers"
  # Key to collect weights
  WEIGHTS = "weights"
  # Key to collect biases
  BIASES = "biases"
  # Key to collect activations
  ACTIVATIONS = "activations"
  # Key to collect update_ops
  UPDATE_OPS = "update_ops"
  # Key to collect losses
  LOSSES = "losses"
  ...

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

# 用户要快速的检索某类变量可以通过这样的语句
all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

用户可以通过语句tf.get_collection(tf.Graph Keys.TRAINABLE_VARIABLES)快速检索变量,或者自定义分组。

3。 后端(C++)Graph数据结构

Graph(计算图)是一组节点和边。计算图是DAG图。计算图的执行过程将根据DAG拓扑进行排序,OP操作将按顺序启动。其中,如果有多个level 0的节点,TensorFlow可以在运行时实现并发,同时执行多个OP操作,以提高执行效率。

class Graph {
     private:
      // 所有已知的op计算函数的注册表
      FunctionLibraryDefinition ops_;

      // GraphDef版本号
      const std::unique_ptr<VersionDef> versions_;

      // 节点node列表,通过id来访问
      std::vector<Node*> nodes_;

      // node个数
      int64 num_nodes_ = 0;

      // 边edge列表,通过id来访问
      std::vector<Edge*> edges_;

      // graph中非空edge的数目
      int num_edges_ = 0;

      // 已分配了内存,但还没使用的node和edge
      std::vector<Node*> free_nodes_;
      std::vector<Edge*> free_edges_;

     const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
       auto e = AllocEdge();
       e->src_ = source;
       e->dst_ = dest;
       e->src_output_ = x;
       e->dst_input_ = y;
       CHECK(source->out_edges_.insert(e).second);
       CHECK(dest->in_edges_.insert(e).second);
       edges_.push_back(e);
       edge_set_.insert(e);
       return e;
        }
 }

Graph的主肢在后面还有结和棱。节点是计算算子的运算,边是算子需要的数据,或者表示节点之间的依赖关系。 Edge 保存指向源节点和目标节点的指针,从而连接两个节点。

Graph图可以由节点和边组成,通过任意节点和边都可以遍历完整的图。 Graph在进行计算时,会按照拓扑结构依次对各个Node进行op计算,最终得到输出结果。度数为0的节点,即数据所依赖的节点已经准备好,可以同时执行,从而提高运行效率。

系统中有一个标准的Graph。当Future启动时,会添加一个Source节点和一个Sink节点。 Source代表Graph的起始节点,Sink是最终节点。 Source ID为0,Sink ID为1,其他节点ID均大于1。

另外,Graph数据结构提供了多种构建和修改计算图的方式。通过这些方法,我们可以添加和删除节点、连接节点之间的边、设置节点属性和操作。例如,您可以通过调用 graph->AddNode() 方法添加新节点,并通过调用 node->AddInputEdge() 方法添加输入边。

4。总结

Graph前端(Python)定义主要包括:

  1. Graph对象包含一系列Operation对象和Tensor对象,用于表示计算单元和数据单元的集合。
  2. 通过为每个Operation分配唯一的id并将其存储在_nodes_by_id字典中,并按节点名称存储在_nodes_by_name字典中,可以快速索引和检索节点信息。
  3. 在图构建过程中,通过OP构造函数创建一个Operation,并将其添加到当前的Graph实例中。
  4. 标准镜像使用堆栈管理,通过push和pop操作进行管理。当前的默认图像是堆栈顶部的图像。

Graph后端数据结构定义(C++)主要包括:

  1. Graph后端成员包括节点和边。节点和边向量的数组维护在数据结构中。这是定义数据结构的典型图。
  2. 默认的Graph在初始化时会添加一个Source节点和一个Sink节点。 Source节点代表计算图的起始节点,Sink节点代表计算图的结束节点。 Source节点的ID为0,Sink节点的ID为1,所有其他节点的ID都大于1。
  3. Graph数据结构提供了一系列的方法来构建和修改计算图,例如 AddEdge 和 `AddNode`

另外,前端GAG图生成后,通过protobuf序列化和反序列化发送到后端运行和优化。

版权声明

本文仅代表作者观点,不代表Code前端网立场。
本文系作者Code前端网发表,如需转载,请注明页面地址。

发表评论:

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

热门