TensorFlow核心组件系列中的Graph核心机制
Graph(计算图)是TensorFlow的核心组件。在TensorFlow中,Graph扮演着重要的角色,用于表示深度学习模型的计算过程和数据流。
Graph的核心机制包括TensorFlow框架的前端和后端系统,它们协同工作来创建和执行计算图。前端系统负责构建计算图并定义模型的结构和操作,而后端系统提供运行时环境并负责对计算图进行操作。通过深入研究Graph的底层机制,我们可以揭示TensorFlow框架的内部工作原理,为模型设计和优化提供指导。
首先简单介绍一下Graph:
在TensorFlow中,Graph是一个由节点(Nodes)和边(Edges)组成的有向无环图。节点代表计算操作,而边代表数据流。每个节点可以接收多个输入边和输出边,建立数据传输和转换关系。这种节点和边的组织结构可以清晰地表示计算过程和模型依赖关系,便于模型优化和并行计算。
在TensorFlow前端中,Operation用于表示图中的节点实例,而Tensor表示图中的边实例,用于连接Op节点。 Op包含节点的计算逻辑和操作类型,而Tensor则承载数据并在节点之间传递。
在C++后端系统中,TensorFlow定义了Edge和Node来表示计算图。另外,Edge和Tensor之间也存在着相互的关系。 Edge是连接Node的纽带,用于传输数据并创建计算依赖,而Tensor则是承载数据到Edge的载体。
Tensor在后端TF系统中维护底层数据的指针和形状信息,并通过引用计数来管理数据生命周期。这样的设计使得TensorFlow可以实现延迟计算和内存复用,提高计算效率和资源利用率。
1。计算图表中Graph的功能特点
Graph的特点之一就是灵活性。 TensorFlow计算图允许用户自由定义模型的结构和操作,以满足各种复杂的深度学习需求。通过添加、删除或修改节点和边,我们可以灵活地设计和修改计算图,以满足不同任务和模型架构的要求。
其次,Graph的特点之一就是计算图的优化。 TensorFlow提供了多种优化技术,可以通过优化计算图来提高模型性能和效率。例如,图剪枝技术可以去除不必要的节点和边,减少计算和内存消耗。同时,利用并行计算技术,可以并行执行计算图中的操作,加速模型训练和推理。此外,量化技术还可以用于精确压缩计算图中的张量,以降低模型的存储成本和计算能力。
此外,Graph还提供了模型可视化和调试的支持。 TensorFlow提供了可视化工具,可以以图形方式呈现计算图,帮助开发人员直观地了解模型的结构和数据流。通过分析计算图,可以定位并解决潜在的问题,提高模型的稳定性和可靠性。
2。前端 (Python) Graph定义
Graph对象将包含许多代表计算单元集合的操作对象。它还间接保存了许多代表数据单元集合的 Tensor 对象。
首先看一下Graph在Python前端侧的定义:
class Graph(object):
def __init__(self):
self._lock = threading.Lock()
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._next_id_counter = 0 # GUARDED_BY(self._lock)
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
self._registered_ops = op_def_registry.get_registered_ops()
def _add_op(self, op):
self._check_not_finalized()
if not isinstance(op, (Tensor, Operation)):
raise TypeError("op must be a Tensor or Operation: %s" % op)
with self._lock:
# pylint: disable=protected-access
if op._id in self._nodes_by_id:
raise ValueError("cannot add an op with id %d as it already "
"exists in the graph" % op._id)
if op.name in self._nodes_by_name:
raise ValueError("cannot add op with name %s as that name "
"is already used" % op.name)
self._nodes_by_id[op._id] = op
self._nodes_by_name[op.name] = op
self._version = max(self._version, op._id)
def add_to_collection(name, value):
get_default_graph().add_to_collection(name, value)
# 替换线程默认图
def as_default(self):
return _default_graph_stack.get_controller(self)
# 栈式管理,push pop
@tf_contextlib.contextmanager
def get_controller(self, default):
try:
context.context_stack.push(default.building_function, default.as_default)
finally:
context.context_stack.pop()
def create_op(
self,
op_type,
inputs,
dtypes=None, # pylint: disable=redefined-outer-name
input_types=None,
name=None,
attrs=None,
op_def=None,
compute_shapes=True,
compute_device=True):
for idx, a in enumerate(inputs):
if not isinstance(a, Tensor):
raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
attrs, op_def, compute_device)
从源码中可以看出,为了快速索引图中节点的信息,给它分配了一个唯一的ID当前Graph范围内的每一个操作都有一个数据字典,里面存储的是Graph中的_nodes_by_id。同时,为了快速通过节点名称索引节点信息,Graph中还存储了一个数据字典_nodes_by_name。
在图创建阶段,使用OP构造函数创建OP,并最终添加到当前的Graph实例中。当图被冻结时,节点无法添加到图中,从而允许Graph实例在多个线程之间安全地共享。
默认堆栈管理图
另外,从图管理中可以看出,默认图采用堆栈管理方式,使用push和pop操作进行管理。当前的默认图像是包顶部的图像。
举个例子:
print tf.get_default_graph()
with tf.Graph().as_default() as g:
print tf.get_default_graph()
print tf.get_default_graph()
<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>
从上面可以看出,当我们在一个范围内创建一个新的图表并将其作为默认图表时,但是当我们离开该范围后,它就变成了原来的默认图表。
图像创建工厂
让我们解释一下下面的图像是如何创建的?
当客户端使用OP构造函数创建操作实例时,最终会调用Future.create_op方法将操作实例注册到图实例中。
也就是说,一方面,Graph充当了运化的工厂,负责运化的创造;另一方面,Graph作为运化的工厂,负责运化的创造。另一方面,Graph充当Operation的仓库,负责Operation的存储、检索、转换等操作。
这个过程通常称为计算图构建。在创建计算图的过程中,运行时不会执行OP操作。它只是描述计算节点之间的依赖关系,并创建一个DAG图来为整个计算过程创建一个总体规划。
此外,TF还提供了Graph Key类,以便更方便地管理节点和信息检索:
class GraphKeys(object):
GLOBAL_VARIABLES = "variables"
# Key to collect local variables that are local to the machine and are not
# saved/restored.
LOCAL_VARIABLES = "local_variables"
# optimizers.
TRAINABLE_VARIABLES = "trainable_variables"
SAVERS = "savers"
# Key to collect weights
WEIGHTS = "weights"
# Key to collect biases
BIASES = "biases"
# Key to collect activations
ACTIVATIONS = "activations"
# Key to collect update_ops
UPDATE_OPS = "update_ops"
# Key to collect losses
LOSSES = "losses"
...
# Key to indicate various ops.
INIT_OP = "init_op"
LOCAL_INIT_OP = "local_init_op"
SUMMARY_OP = "summary_op"
GLOBAL_STEP = "global_step"
# Used to count the number of evaluations performed during a single evaluation
# run.
EVAL_STEP = "eval_step"
TRAIN_OP = "train_op"
# Key for control flow context.
COND_CONTEXT = "cond_context"
WHILE_CONTEXT = "while_context"
# Used to store v2 summary names.
_SUMMARY_COLLECTION = "_SUMMARY_V2"
# List of all collections that keep track of variables.
_VARIABLE_COLLECTIONS = [
GLOBAL_VARIABLES,
LOCAL_VARIABLES,
METRIC_VARIABLES,
MODEL_VARIABLES,
TRAINABLE_VARIABLES,
MOVING_AVERAGE_VARIABLES,
CONCATENATED_VARIABLES,
TRAINABLE_RESOURCE_VARIABLES,
]
# 用户要快速的检索某类变量可以通过这样的语句
all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
用户可以使用tf.get_collection(tf.Graph Keys.TRAINABLE_VARIABLES groupings)快速检索变量或编辑组。
3。 后端(C++)Graph数据结构
Graph(计算图)是一组节点和边。计算图是 DAG 图。计算图执行过程将按照DAG拓扑排序,并逐步开始OP操作。如果其中存在多个度数为0的节点,TensorFlow可以实现运行时并发,同时执行多个OP操作,以提高执行效率。
class Graph {
private:
// 所有已知的op计算函数的注册表
FunctionLibraryDefinition ops_;
// GraphDef版本号
const std::unique_ptr<VersionDef> versions_;
// 节点node列表,通过id来访问
std::vector<Node*> nodes_;
// node个数
int64 num_nodes_ = 0;
// 边edge列表,通过id来访问
std::vector<Edge*> edges_;
// graph中非空edge的数目
int num_edges_ = 0;
// 已分配了内存,但还没使用的node和edge
std::vector<Node*> free_nodes_;
std::vector<Edge*> free_edges_;
const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
auto e = AllocEdge();
e->src_ = source;
e->dst_ = dest;
e->src_output_ = x;
e->dst_input_ = y;
CHECK(source->out_edges_.insert(e).second);
CHECK(dest->in_edges_.insert(e).second);
edges_.push_back(e);
edge_set_.insert(e);
return e;
}
}
后台Graph的主要成员也是节点和边。节点node是计算算子Operation,边是算子所需的数据或者表示节点之间的依赖关系。边保存指向其源节点和目标节点的指针,连接这两个节点。
Graph图可以由一个节点和一条边组成,并且任意节点和任意边都可以遍历整个图。Graph进行计算时,按照拓扑结构的顺序对各个节点进行运算计算,最终得到输出结果。度数为0的节点,即依赖数据已经准备好的节点,可以并发执行,提高运行效率。
系统默认Graph。当Graph初始化时,会添加一个Source和Sink节点。 Source代表起始节点,Sink代表结束节点。源ID为0,宿ID为1,其他节点的ID均大于1。
此外,Graph数据结构提供了多种创建和修改计算图的方法。通过这些方法,我们可以添加和删除节点、连接节点之间的边、设置节点属性和操作。例如,您可以通过调用 graph->AddNode()
方法添加新节点,并通过调用 node->AddInputEdge()
方法添加输入边。
4。总结
Graph前端定义(Python)主要包括:
- 一个Graph对象包含多个Operation对象和Tensor对象,用来表示计算单元和数据单元的集合。
- 通过为每个操作分配唯一的 ID 并将其存储在 _nodes_by_id 字典中,并按节点名称存储在 _nodes_by_name 字典中,快速索引和检索节点信息。
- 在图构建过程中,使用OP构造函数创建一个操作并将其添加到当前的Graph实例中。
- 默认镜像采用栈管理,通过push和pop操作进行管理。当前的默认图像是包顶部的图像。
后端定义(C++) Graph数据结构主要包括:
- 后端Graph成员包括节点和边。节点和边向量字段维护在数据结构中。它们是定义数据结构的典型图。
- 默认的Graph会在初始化的时候添加一个Source和Sink节点。 Source节点代表计算图的起始节点,Sink节点代表计算图的结束节点。 Source节点的ID为0,Sink节点的ID为1,其他节点的ID均大于1。
- Graph数据结构提供了多种创建和修改计算图的方法,例如作为 AddEdge 和 `AddNode`。
生成前端GAG图后,额外通过protobuf序列化和反序列化发送到后端执行和优化。
版权声明
本文仅代表作者观点,不代表Code前端网立场。
本文系作者Code前端网发表,如需转载,请注明页面地址。
发表评论:
◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。