Code前端首页关于Code前端联系我们

TensorFlow核心组件系列中的Graph核心机制

terry 2年前 (2023-09-23) 阅读数 67 #AI人工智能

Graph(计算图)是TensorFlow的核心组件。在TensorFlow中,Graph扮演着重要的角色,用于表示深度学习模型的计算过程和数据流。

Graph的核心机制包括TensorFlow框架的前端和后端系统,它们协同工作来创建和执行计算图。前端系统负责构建计算图并定义模型的结构和操作,而后端系统提供运行时环境并负责对计算图进行操作。通过深入研究Graph的底层机制,我们可以揭示TensorFlow框架的内部工作原理,为模型设计和优化提供指导。

首先简单介绍一下Graph:

在TensorFlow中,Graph是一个由节点(Nodes)和边(Edges)组成的有向无环图。节点代表计算操作,而边代表数据流。每个节点可以接收多个输入边和输出边,建立数据传输和转换关系。这种节点和边的组织结构可以清晰地表示计算过程和模型依赖关系,便于模型优化和并行计算。

在TensorFlow前端中,Operation用于表示图中的节点实例,而Tensor表示图中的边实例,用于连接Op节点。 Op包含节点的计算逻辑和操作类型,而Tensor则承载数据并在节点之间传递。

在C++后端系统中,TensorFlow定义了Edge和Node来表示计算图。另外,Edge和Tensor之间也存在着相互的关系。 Edge是连接Node的纽带,用于传输数据并创建计算依赖,而Tensor则是承载数据到Edge的载体。

Tensor在后端TF系统中维护底层数据的指针和形状信息,并通过引用计数来管理数据生命周期。这样的设计使得TensorFlow可以实现延迟计算和内存复用,提高计算效率和资源利用率。

1。计算图表中Graph的功能特点

Graph的特点之一就是灵活性。 TensorFlow计算图允许用户自由定义模型的结构和操作,以满足各种复杂的深度学习需求。通过添加、删除或修改节点和边,我们可以灵活地设计和修改计算图,以满足不同任务和模型架构的要求。

其次,Graph的特点之一就是计算图的优化。 TensorFlow提供了多种优化技术,可以通过优化计算图来提高模型性能和效率。例如,图剪枝技术可以去除不必要的节点和边,减少计算和内存消耗。同时,利用并行计算技术,可以并行执行计算图中的操作,加速模型训练和推理。此外,量化技术还可以用于精确压缩计算图中的张量,以降低模型的存储成本和计算能力。

此外,Graph还提供了模型可视化和调试的支持。 TensorFlow提供了可视化工具,可以以图形方式呈现计算图,帮助开发人员直观地了解模型的结构和数据流。通过分析计算图,可以定位并解决潜在的问题,提高模型的稳定性和可靠性。

2。前端 (Python) Graph定义

Graph对象将包含许多代表计算单元集合的操作对象。它还间接保存了许多代表数据单元集合的 Tensor 对象。

首先看一下Graph在Python前端侧的定义:

class Graph(object):
       def __init__(self):
         self._lock = threading.Lock()
         self._nodes_by_id = dict()    # GUARDED_BY(self._lock)
         self._next_id_counter = 0     # GUARDED_BY(self._lock)
         self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
         self._registered_ops = op_def_registry.get_registered_ops()

        def _add_op(self, op):
                self._check_not_finalized()
                if not isinstance(op, (Tensor, Operation)):
                  raise TypeError("op must be a Tensor or Operation: %s" % op)
                with self._lock:
                  # pylint: disable=protected-access
                  if op._id in self._nodes_by_id:
                    raise ValueError("cannot add an op with id %d as it already "
                                     "exists in the graph" % op._id)
                  if op.name in self._nodes_by_name:
                    raise ValueError("cannot add op with name %s as that name "
                                     "is already used" % op.name)
                  self._nodes_by_id[op._id] = op
                  self._nodes_by_name[op.name] = op
                  self._version = max(self._version, op._id)

         def add_to_collection(name, value):
                      get_default_graph().add_to_collection(name, value)

         # 替换线程默认图
               def as_default(self):
                   return _default_graph_stack.get_controller(self)

               # 栈式管理,push pop
               @tf_contextlib.contextmanager
               def get_controller(self, default):
                    try:
                      context.context_stack.push(default.building_function, default.as_default)
                    finally:
                      context.context_stack.pop()

         def create_op(
                  self,
                  op_type,
                  inputs,
                  dtypes=None,  # pylint: disable=redefined-outer-name
                  input_types=None,
                  name=None,
                  attrs=None,
                  op_def=None,
                  compute_shapes=True,
                  compute_device=True):
                        for idx, a in enumerate(inputs):
                          if not isinstance(a, Tensor):
                            raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
                        return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
                                                        attrs, op_def, compute_device)

从源码中可以看出,为了快速索引图中节点的信息,给它分配了一个唯一的ID当前Graph范围内的每一个操作都有一个数据字典,里面存储的是Graph中的_nodes_by_id。同时,为了快速通过节点名称索引节点信息,Graph中还存储了一个数据字典_nodes_by_name。

在图创建阶段,使用OP构造函数创建OP,并最终添加到当前的Graph实例中。当图被冻结时,节点无法添加到图中,从而允许Graph实例在多个线程之间安全地共享。

默认堆栈管理图

另外,从图管理中可以看出,默认图采用堆栈管理方式,使用push和pop操作进行管理。当前的默认图像是包顶部的图像。

举个例子:

print tf.get_default_graph()

with tf.Graph().as_default() as g:
    print tf.get_default_graph()

print tf.get_default_graph()

<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>

从上面可以看出,当我们在一个范围内创建一个新的图表并将其作为默认图表时,但是当我们离开该范围后,它就变成了原来的默认图表。

图像创建工厂

让我们解释一下下面的图像是如何创建的?

当客户端使用OP构造函数创建操作实例时,最终会调用Future.create_op方法将操作实例注册到图实例中。

也就是说,一方面,Graph充当了运化的工厂,负责运化的创造;另一方面,Graph作为运化的工厂,负责运化的创造。另一方面,Graph充当Operation的仓库,负责Operation的存储、检索、转换等操作。 TensorFlow核心组件系列之Graph的底层机制

这个过程通常称为计算图构建。在创建计算图的过程中,运行时不会执行OP操作。它只是描述计算节点之间的依赖关系,并创建一个DAG图来为整个计算过程创建一个总体规划。

此外,TF还提供了Graph Key类,以便更方便地管理节点和信息检索:

class GraphKeys(object):
    GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  SAVERS = "savers"
  # Key to collect weights
  WEIGHTS = "weights"
  # Key to collect biases
  BIASES = "biases"
  # Key to collect activations
  ACTIVATIONS = "activations"
  # Key to collect update_ops
  UPDATE_OPS = "update_ops"
  # Key to collect losses
  LOSSES = "losses"
  ...

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

# 用户要快速的检索某类变量可以通过这样的语句
all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

用户可以使用tf.get_collection(tf.Graph Keys.TRAINABLE_VARIABLES groupings)快速检索变量或编辑组。

3。 后端(C++)Graph数据结构

Graph(计算图)是一组节点和边。计算图是 DAG 图。计算图执行过程将按照DAG拓扑排序,并逐步开始OP操作。如果其中存在多个度数为0的节点,TensorFlow可以实现运行时并发,同时执行多个OP操作,以提高执行效率。

class Graph {
     private:
      // 所有已知的op计算函数的注册表
      FunctionLibraryDefinition ops_;

      // GraphDef版本号
      const std::unique_ptr<VersionDef> versions_;

      // 节点node列表,通过id来访问
      std::vector<Node*> nodes_;

      // node个数
      int64 num_nodes_ = 0;

      // 边edge列表,通过id来访问
      std::vector<Edge*> edges_;

      // graph中非空edge的数目
      int num_edges_ = 0;

      // 已分配了内存,但还没使用的node和edge
      std::vector<Node*> free_nodes_;
      std::vector<Edge*> free_edges_;

     const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
       auto e = AllocEdge();
       e->src_ = source;
       e->dst_ = dest;
       e->src_output_ = x;
       e->dst_input_ = y;
       CHECK(source->out_edges_.insert(e).second);
       CHECK(dest->in_edges_.insert(e).second);
       edges_.push_back(e);
       edge_set_.insert(e);
       return e;
        }
 }

后台Graph的主要成员也是节点和边。节点node是计算算子Operation,边是算子所需的数据或者表示节点之间的依赖关系。边保存指向其源节点和目标节点的指针,连接这两个节点。

Graph图可以由一个节点和一条边组成,并且任意节点和任意边都可以遍历整个图。Graph进行计算时,按照拓扑结构的顺序对各个节点进行运算计算,最终得到输出结果。度数为0的节点,即依赖数据已经准备好的节点,可以并发执行,提高运行效率。

系统默认Graph。当Graph初始化时,会添加一个Source和Sink节点。 Source代表起始节点,Sink代表结束节点。源ID为0,宿ID为1,其他节点的ID均大于1。

此外,Graph数据结构提供了多种创建和修改计算图的方法。通过这些方法,我们可以添加和删除节点、连接节点之间的边、设置节点属性和操作。例如,您可以通过调用 graph->AddNode() 方法添加新节点,并通过调用 node->AddInputEdge() 方法添加输入边。

4。总结

Graph前端定义(Python)主要包括:

  1. 一个Graph对象包含多个Operation对象和Tensor对象,用来表示计算单元和数据单元的集合。
  2. 通过为每个操作分配唯一的 ID 并将其存储在 _nodes_by_id 字典中,并按节点名称存储在 _nodes_by_name 字典中,快速索引和检索节点信息。
  3. 在图构建过程中,使用OP构造函数创建一个操作并将其添加到当前的Graph实例中。
  4. 默认镜像采用栈管理,通过push和pop操作进行管理。当前的默认图像是包顶部的图像。

后端定义(C++) Graph数据结构主要包括:

  1. 后端Graph成员包括节点和边。节点和边向量字段维护在数据结构中。它们是定义数据结构的典型图。
  2. 默认的Graph会在初始化的时候添加一个Source和Sink节点。 Source节点代表计算图的起始节点,Sink节点代表计算图的结束节点。 Source节点的ID为0,Sink节点的ID为1,其他节点的ID均大于1。
  3. Graph数据结构提供了多种创建和修改计算图的方法,例如作为 AddEdge 和 `AddNode`

生成前端GAG图后,额外通过protobuf序列化和反序列化发送到后端执行和优化。

版权声明

本文仅代表作者观点,不代表Code前端网立场。
本文系作者Code前端网发表,如需转载,请注明页面地址。

发表评论:

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

热门