An Overview inside Distributed Tensorflow Workflow

2017/12/19

This article summarizes my understanding of distributed Tensorflow's workflow. I'd like to divide it into the following four parts:

  • Create a server;
  • Create a session;
  • Build a computation graph;
  • Run a session.
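
Before walking through the internals, it may help to see the entry point that exercises all of these steps. Below is a minimal sketch (the job name, task index, and address are invented for illustration) of building a ServerDef and bringing up a server via tensorflow::NewServer(), the same call the article starts from:

#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

int main() {
  // Describe a one-job, one-task cluster. The layout here is illustrative.
  tensorflow::ServerDef server_def;
  tensorflow::JobDef* job = server_def.mutable_cluster()->add_job();
  job->set_name("worker");
  (*job->mutable_tasks())[0] = "localhost:2222";
  server_def.set_job_name("worker");
  server_def.set_task_index(0);
  server_def.set_protocol("grpc");

  // NewServer() dispatches to GrpcServerFactory, which builds a GrpcServer.
  std::unique_ptr<tensorflow::ServerInterface> server;
  TF_CHECK_OK(tensorflow::NewServer(server_def, &server));
  TF_CHECK_OK(server->Start());  // spawn the master/worker service threads
  TF_CHECK_OK(server->Join());   // block until the server shuts down
  return 0;
}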

So let’s start from creating a server.

Create a server

grpc_tensorflow_server.cc

Everything starts from creating a new server. A server exposes three lifecycle operations, Start(), Join(), and Stop(), and internally transitions between the states NEW, STARTED, and STOPPED:

int main(int argc, char* argv[]) {
  ...
  TF_QCHECK_OK(tensorflow::NewServer(server_def, &server));
  TF_QCHECK_OK(server->Start());
  TF_QCHECK_OK(server->Join());
}

grpc_server_lib.cc

Create a gRPC server and call the Init() function

Inside the Init() function, the server first creates both the master environment and the worker environment.

The master environment mainly holds:

  • Local devices
  • A worker cache
  • The master session factory

And the worker environment holds:

  • Local devices
  • A device manager (device_mgr)
  • A rendezvous manager (rendezvous_mgr)
  • A session manager (session_mgr)
  • A compute thread pool (compute_pool)
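
Condensed from the fields touched in Init() below, the two environments can be pictured roughly as follows. This is a simplified sketch, not the actual definitions in master_env.h and worker_env.h:

// Simplified sketch of tensorflow/core/distributed_runtime/master_env.h
struct MasterEnv {
  Env* env = nullptr;
  std::vector<Device*> local_devices;  // devices visible to this process
  // Factory invoked by Master::CreateSession() to build a MasterSession
  // (matches the lambda installed in GrpcServer::Init() below).
  std::function<MasterSession*(
      SessionOptions, MasterEnv*,
      std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
      std::unique_ptr<WorkerCacheInterface>,
      std::unique_ptr<DeviceSet>)>
      master_session_factory;
  // Factory that builds a worker cache from WorkerCacheFactoryOptions.
  std::function<Status(const WorkerCacheFactoryOptions&,
                       WorkerCacheInterface**)>
      worker_cache_factory;
};

// Simplified sketch of tensorflow/core/distributed_runtime/worker_env.h
struct WorkerEnv {
  Env* env = nullptr;
  SessionMgr* session_mgr = nullptr;   // tracks per-master WorkerSessions
  std::vector<Device*> local_devices;
  DeviceMgr* device_mgr = nullptr;     // owns the local devices
  RendezvousMgrInterface* rendezvous_mgr = nullptr;  // per-step rendezvous
  thread::ThreadPool* compute_pool = nullptr;        // runs op kernels
};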

The gRPC server then starts one thread for the master service and another for the worker service. We will see how these two services work in the following sections:

class GrpcServerFactory : public ServerFactory {
  ...
  Status NewServer(const ServerDef& server_def,
                   std::unique_ptr<ServerInterface>* out_server) override {
    return GrpcServer::Create(server_def, Env::Default(), out_server);
  }
};

...

Status GrpcServer::Create(const ServerDef& server_def, Env* env,
                          std::unique_ptr<ServerInterface>* out_server) {
  std::unique_ptr<GrpcServer> ret(
      new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
  ServiceInitFunction service_func = nullptr;
  TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
  *out_server = std::move(ret);
  return Status::OK();
}

...

Status GrpcServer::Init(
    ServiceInitFunction service_func,
    const RendezvousMgrCreationFunction& rendezvous_mgr_func,
    const WorkerCreationFunction& worker_func) {
  ...
  master_env_.env = env_;
  worker_env_.env = env_;
  ...
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
                                               &master_env_.local_devices));
  worker_env_.local_devices = master_env_.local_devices;
  worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
  worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
                                   ? new RpcRendezvousMgr(&worker_env_)
                                   : rendezvous_mgr_func(&worker_env_);
  ...
  master_impl_ = CreateMaster(&master_env_);
  master_service_ = NewGrpcMasterService(
      master_impl_.get(), config.operation_timeout_in_ms(), &builder);
  worker_impl_ =
      worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
  worker_service_ =
      NewGrpcWorkerService(worker_impl_.get(), &builder).release();
  ...
  TF_RETURN_IF_ERROR(
      WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
  ...
  // Set up worker environment.
  worker_env_.session_mgr = new SessionMgr(
      &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
      std::unique_ptr<WorkerCacheInterface>(worker_cache),
      [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
        WorkerCacheFactoryOptions options(server_def);
        return WorkerCacheFactory(options, worker_cache);
      });
  worker_env_.compute_pool = ComputePool(sess_opts);
  ...
  master_env_.master_session_factory =
      [config](
          SessionOptions options, const MasterEnv* env,
          std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
          std::unique_ptr<WorkerCacheInterface> worker_cache,
          std::unique_ptr<DeviceSet> device_set) {
        options.config.MergeFrom(config);
        return new MasterSession(options, env, std::move(remote_devs),
                                 std::move(worker_cache), std::move(device_set),
                                 CreateNoOpStatsPublisher);
      };
  master_env_.worker_cache_factory =
      [this](const WorkerCacheFactoryOptions& options,
             WorkerCacheInterface** worker_cache) {
        return WorkerCacheFactory(options, worker_cache);
      };
  // Provide direct access to the master from in-process clients.
  LocalMaster::Register(target(), master_impl_.get(),
                        config.operation_timeout_in_ms());
  ...
}

...

Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                      WorkerCacheInterface** worker_cache) {
  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
      channel_cache.release(), worker_impl_.get(), name_prefix);
  ...
}

...

Status GrpcServer::Start() {
  mutex_lock l(mu_);
  switch (state_) {
    case NEW: {
      master_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_master_service",
                            [this] { master_service_->HandleRPCsLoop(); }));
      worker_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_worker_service",
                            [this] { worker_service_->HandleRPCsLoop(); }));
  ...
}

Start() launches the master service. HandleRPCsLoop() processes incoming gRPC requests: each completed request invokes its type-specific handler, and each handler re-enqueues a fresh request of the same type (via the ENQUEUE_REQUEST macro) so the service keeps accepting calls.

grpc_master_service.cc

void HandleRPCsLoop() override {
  ENQUEUE_REQUEST(CreateSession, true);
  ENQUEUE_REQUEST(ExtendSession, false);
  for (int i = 0; i < 100; ++i) {
    ENQUEUE_REQUEST(PartialRunSetup, false);
    ENQUEUE_REQUEST(RunStep, true);
  }
  ENQUEUE_REQUEST(CloseSession, false);
  ENQUEUE_REQUEST(ListDevices, false);
  ENQUEUE_REQUEST(Reset, false);

  void* tag;
  bool ok;
  while (cq_->Next(&tag, &ok)) {
    UntypedCall<GrpcMasterService>::Tag* callback_tag =
        static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
    if (callback_tag) {
      callback_tag->OnCompleted(this, ok);
    } else {
      // NOTE(mrry): A null `callback_tag` indicates that this is
      // the shutdown alarm.
      cq_->Shutdown();
    }
  }
}

...

// RPC handler for creating a session.
void CreateSessionHandler(
    MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
  master_impl_->CreateSession(&call->request, &call->response,
                              [call](const Status& status) {
                                call->SendResponse(ToGrpcStatus(status));
                              });
  ENQUEUE_REQUEST(CreateSession, true);
}

...

// RPC handler for running one step in a session.
void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
  ...
  master_impl_->RunStep(call_opts, wrapped_request, wrapped_response,
                        [call, call_opts, wrapped_request, wrapped_response,
                         trace](const Status& status) {
                          call->ClearCancelCallback();
                          delete call_opts;
                          delete wrapped_request;
                          delete trace;
                          call->SendResponse(ToGrpcStatus(status));
                        });
  ENQUEUE_REQUEST(RunStep, true);
}
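
Both HandleRPCsLoop() implementations (master and worker) rely on the same trick: the void* tag traveling through the gRPC completion queue is the address of a call object that knows which handler to run. The toy program below imitates that pattern with a plain std::queue instead of a real grpc::CompletionQueue; all the types in it are invented for illustration, not TensorFlow's:

#include <functional>
#include <iostream>
#include <queue>

// Stand-in for UntypedCall<Service>::Tag: a heap-allocated object whose
// address travels through the queue as a void* tag.
struct Tag {
  std::function<void(bool)> on_completed;
};

// Stand-in for the gRPC completion queue.
std::queue<void*> cq;

// Stand-in for ENQUEUE_REQUEST(Method, ...): register interest in the next
// request of a given type by enqueueing a tag bound to its handler.
void EnqueueRequest(const char* method) {
  cq.push(new Tag{[method](bool ok) {
    std::cout << "handling " << method << " (ok=" << ok << ")\n";
    // A real handler would re-enqueue a request of the same type here.
  }});
}

int main() {
  EnqueueRequest("CreateSession");
  EnqueueRequest("RunStep");

  // The HandleRPCsLoop() analogue: pop tags and dispatch to their handlers.
  while (!cq.empty()) {
    void* tag = cq.front();
    cq.pop();
    auto* callback_tag = static_cast<Tag*>(tag);
    callback_tag->on_completed(/*ok=*/true);
    delete callback_tag;
  }
  return 0;
}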

The Master class is responsible for creating and maintaining master sessions:

master.cc

// Master implements the service MasterService.
//
// A Master maintains the state of live graph computation
// sessions, each session orchestrates both local and remote devices
// to carry out the graph computation.
//
// A Master knows ahead of time local devices available as
// client devices.
//
// A Master discovers remote devices on-demand and keeps track of
// statistics of those remote devices.
//
// Each session analyzes the graph, places nodes across available
// devices, and ultimately drives the graph computation by initiating
// RunGraph on the workers.

void Master::CreateSession(const CreateSessionRequest* req,
                           CreateSessionResponse* resp, MyClosure done) {
  ...
  status = ValidateExternalGraphDefSyntax(req->graph_def());
  ...
  // Create the worker cache from the computed server_def.
  status = env_->worker_cache_factory(worker_cache_factory_options,
                                      &worker_cache);
  ...
  MasterSession* session = env_->master_session_factory(
      options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
      std::move(device_set));
  ...
  status = session->Create(gdef, worker_cache_factory_options);
  ...
}

void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
                     MutableRunStepResponseWrapper* resp, MyClosure done) {
  ...
  SchedClosure([this, start_time, session, opts, req, resp, done]() {
    Status status = session->Run(opts, *req, resp);
    ...
  });
}

master_session.cc

Status MasterSession::Create(GraphDef* graph_def,
                             const WorkerCacheFactoryOptions& options) {
  ...
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        graph_def, execution_options, &execution_state_));
  }
  ...
  return CreateWorkerSessions(options);
  ...
}

...

Status MasterSession::CreateWorkerSessions(
    const WorkerCacheFactoryOptions& options) {
  ...
  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    // Request and responses used for a given worker.
    CreateWorkerSessionRequest request;
    CreateWorkerSessionResponse response;
    ...
  };
  ...
  std::vector<WorkerGroup> workers(worker_names.size());
  ...
  // Create all the workers & kick off the computations.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    ...
    workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
    ...
    workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
    ...
  }
}

grpc_worker_cache.cc

WorkerInterface* CreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    ...
    return NewGrpcRemoteWorker(channel, &completion_queue_, &logger_);
  }
}

grpc_remote_worker.cc

void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                              CreateWorkerSessionResponse* response,
                              StatusCallback done) override {
  IssueRequest(request, response, createworkersession_, std::move(done));
}

grpc_worker_service.cc

void HandleRPCsLoop() override {
  // TODO(mrry): This may require performance engineering. We can
  // add more threads to service the completion queue, and add more
  // of various request types if they are short and frequent.
  // Currently we allow unbounded numbers of pending calls for each
  // method, by re-enqueuing a request before the previous one
  // completes, and we may decide to bound some of the request
  // types.
  ENQUEUE_REQUEST(GetStatus, false);
  ENQUEUE_REQUEST(CreateWorkerSession, false);
  ENQUEUE_REQUEST(DeleteWorkerSession, false);
  ENQUEUE_REQUEST(CleanupAll, false);
  ENQUEUE_REQUEST(RegisterGraph, false);
  ENQUEUE_REQUEST(DeregisterGraph, false);

  // TODO(mrry): Determine a better policy for enqueuing the appropriate
  // number of each request type.
  for (int i = 0; i < 1000; ++i) {
    EnqueueRecvTensorRequestRaw();
  }
  for (int i = 0; i < 100; ++i) {
    ENQUEUE_REQUEST(RunGraph, true);
  }
  for (int i = 0; i < 100; ++i) {
    ENQUEUE_REQUEST(CleanupGraph, false);
  }

  ENQUEUE_REQUEST(Logging, false);
  ENQUEUE_REQUEST(Tracing, false);

  void* tag;
  bool ok;

  while (cq_->Next(&tag, &ok)) {
    UntypedCall<GrpcWorkerService>::Tag* callback_tag =
        static_cast<UntypedCall<GrpcWorkerService>::Tag*>(tag);
    if (callback_tag) {
      callback_tag->OnCompleted(this, ok);
    } else {
      // NOTE(mrry): A null `callback_tag` indicates that this is
      // the shutdown alarm.
      cq_->Shutdown();
    }
  }
}

...

void CreateWorkerSessionHandler(
    WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
        call) {
  Schedule([this, call]() {
    Status s = worker_->CreateWorkerSession(&call->request, &call->response);
    call->SendResponse(ToGrpcStatus(s));
  });
  ENQUEUE_REQUEST(CreateWorkerSession, false);
}

...

void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
  Schedule([this, call]() {
    CallOptions* call_opts = new CallOptions;
    ProtoRunGraphRequest* wrapped_request =
        new ProtoRunGraphRequest(&call->request);
    NonOwnedProtoRunGraphResponse* wrapped_response =
        new NonOwnedProtoRunGraphResponse(&call->response);
    call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
    worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
                           [call, call_opts, wrapped_request,
                            wrapped_response](const Status& s) {
                             call->ClearCancelCallback();
                             delete call_opts;
                             delete wrapped_request;
                             delete wrapped_response;
                             call->SendResponse(ToGrpcStatus(s));
                           });
  });
  ENQUEUE_REQUEST(RunGraph, true);
}

worker.cc

void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(request->session_handle(),
                                              request->server_def());
  done(s);
}

session_mgr.cc

Status SessionMgr::CreateSession(const string& session,
                                 const ServerDef& server_def) {
  ...
  std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
  ...
  std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
      session, worker_name,
      std::unique_ptr<WorkerCacheInterface>(worker_cache),
      std::move(device_mgr), std::move(graph_mgr)));
  ...
}

worker_session.cc

WorkerSession::WorkerSession(const string& session_name,
                             const string& worker_name,
                             std::unique_ptr<WorkerCacheInterface> worker_cache,
                             std::unique_ptr<DeviceMgr> device_mgr,
                             std::unique_ptr<GraphMgr> graph_mgr)
    : session_name(session_name),
      worker_name(worker_name),
      worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
      device_mgr(std::move(device_mgr)),
      graph_mgr(std::move(graph_mgr)),
      cluster_flr(new ClusterFunctionLibraryRuntime(this)) {}

}  // namespace tensorflow

Graph execution

A graph definition is passed through the session and converted into an in-memory Graph object by the master.

graph.proto

// Represents the graph of operations
message GraphDef {
  repeated NodeDef node = 1;
  ...
};

node_def.proto

message NodeDef {
  // The name given to this operator. Used for naming inputs,
  // logging, visualization, etc. Unique within a single GraphDef.
  // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
  string name = 1;

  // The operation name. There may be custom parameters in attrs.
  // Op names starting with an underscore are reserved for internal use.
  string op = 2;

  // Each input is "node:src_output" with "node" being a string name and
  // "src_output" indicating which output tensor to use from "node". If
  // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
  // may optionally be followed by control inputs that have the format
  // "^node".
  repeated string input = 3;

  // A (possibly partial) specification for the device on which this
  // node should be placed.
  // The expected syntax for this string is as follows:
  //
  // DEVICE_SPEC ::= PARTIAL_SPEC
  //
  // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
  // CONSTRAINT ::= ("job:" JOB_NAME)
  //              | ("replica:" [1-9][0-9]*)
  //              | ("task:" [1-9][0-9]*)
  //              | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
  //
  // Valid values for this string include:
  // * "/job:worker/replica:0/task:1/device:GPU:3"  (full specification)
  // * "/job:worker/device:GPU:3"                   (partial specification)
  // * ""                                           (no specification)
  //
  // If the constraints do not resolve to a single device (or if this
  // field is empty or not present), the runtime will attempt to
  // choose a device automatically.
  string device = 4;

  // Operation-specific graph-construction-time configuration.
  // Note that this should include all attrs defined in the
  // corresponding OpDef, including those with a value matching
  // the default -- this allows the default to change and makes
  // NodeDefs easier to interpret on their own. However, if
  // an attr with a default is not specified in this list, the
  // default will be used.
  // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
  // one of the names from the corresponding OpDef's attr field).
  // The values must have a type matching the corresponding OpDef
  // attr's type field.
  // TODO(josh11b): Add some examples here showing best practices.
  map<string, AttrValue> attr = 5;
};
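
To make the NodeDef fields concrete, here is a small sketch (the node names, ops, and device string are chosen just for illustration) that builds a two-node GraphDef with the generated protobuf API, including a data input in "node:src_output" form and a control input in "^node" form:

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

tensorflow::GraphDef MakeToyGraphDef() {
  tensorflow::GraphDef gdef;

  // A placeholder node, pinned to a partially specified device.
  tensorflow::NodeDef* x = gdef.add_node();
  x->set_name("x");
  x->set_op("Placeholder");
  x->set_device("/job:worker/task:0");
  (*x->mutable_attr())["dtype"].set_type(tensorflow::DT_FLOAT);

  // An identity node consuming x's output 0, plus a control input.
  tensorflow::NodeDef* y = gdef.add_node();
  y->set_name("y");
  y->set_op("Identity");
  y->add_input("x:0");  // data input: "node:src_output"
  y->add_input("^x");   // control input: "^node"
  (*y->mutable_attr())["T"].set_type(tensorflow::DT_FLOAT);

  return gdef;
}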

graph_execution_state.cc

/* static */ Status GraphExecutionState::MakeForBaseGraph(
    GraphDef* graph_def, const GraphExecutionStateOptions& options,
    std::unique_ptr<GraphExecutionState>* out_state) {
  std::unique_ptr<GraphExecutionState> ret(
      new GraphExecutionState(graph_def, options));
  ...
  TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions()));
  ...
}

Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
  ...
  std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
  ...
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));
  ...
}

graph_constructor.cc

Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(
      opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
      /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
      /*unused_input_map_keys=*/nullptr);
}
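
Tying the last two snippets together, converting a GraphDef into an in-memory Graph looks roughly like this (MakeToyGraphDef() is the hypothetical helper sketched earlier, not a TensorFlow function):

#include <memory>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"

tensorflow::Status BuildToyGraph(std::unique_ptr<tensorflow::Graph>* out) {
  tensorflow::GraphDef gdef = MakeToyGraphDef();  // hypothetical helper above
  auto graph =
      std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
  tensorflow::GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(
      tensorflow::ConvertGraphDefToGraph(opts, gdef, graph.get()));
  // The graph now holds one Node per NodeDef, plus implicit SOURCE/SINK nodes.
  *out = std::move(graph);
  return tensorflow::Status::OK();
}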

Run a session

master_session.cc

Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                          MutableRunStepResponseWrapper* resp) {
  ...
  status = DoRunWithLocalExecution(opts, req, resp);
  ...
}

...

Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
                                ReffedClientGraph** rcg, bool is_partial) {
  ...
    auto entry = new ReffedClientGraph(
        handle_, opts, std::move(client_graph), session_opts_,
        stats_publisher_factory_, execution_state_.get(), is_partial,
        worker_cache, !should_delete_worker_sessions_);
    iter = m->insert({hash, entry}).first;
    VLOG(1) << "Preparing to execute new graph";
  }
  *rcg = iter->second;
  ...
}

...

Status MasterSession::DoRunWithLocalExecution(
    CallOptions* opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp) {
  ...
  ReffedClientGraph* rcg = nullptr;
  ...
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
  ...
  pss.collect_partition_graphs = req.options().output_partition_graphs();
  ...
  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_, false);
  if (s.ok()) {
    ...
    // Schedule post-processing and cleanup to be done asynchronously.
    rcg->ProcessStats(step_id, &pss, ph.get(), req.options(),
                      resp->mutable_metadata());
    ...
  }
  ...
}

...

Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
  ...
  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(popts));
  ...
}

...

Status MasterSession::ReffedClientGraph::RegisterPartitions(
    const PartitionOptions& popts) {
  Status s = DoBuildPartitions(popts, &graph_defs);
  ...
  s = DoRegisterPartitions(popts, std::move(graph_defs));
  ...
}

Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64 step_id, int64 execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp, CancellationManager* cm,
    const bool is_last_partial_run) {
  ...
  const int num = partitions_.size();
  RunManyGraphs calls(num);

  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* c = calls.get(i);
    c->req.reset(part.worker->CreateRunGraphRequest());
    c->resp.reset(part.worker->CreateRunGraphResponse());
    ...
  }

  // Issues RunGraph calls.
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* call = calls.get(i);
    TRACEPRINTF("Partition %d %s", i, part.name.c_str());
    part.worker->RunGraphAsync(
        &call->opts, call->req.get(), call->resp.get(),
        std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
  }
  ...
  calls.Wait();
  ...
  // Collects fetches.
  Status status = calls.status();
  if (status.ok()) {
    for (int i = 0; i < num; ++i) {
      const Part& part = partitions_[i];
      MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
      for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
        auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
        if (iter == part.key_fetch.end()) {
          status.Update(errors::Internal("Unexpected fetch key: ",
                                         run_graph_resp->recv_key(j)));
          break;
        }
        const string& fetch = iter->second;
        status.Update(
            resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
        if (!status.ok()) {
          break;
        }
      }
      if (pss->collect_timeline) {
        pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
      }
      if (pss->collect_costs) {
        CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
        for (int j = 0; j < cost_graph->node_size(); ++j) {
          resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
              cost_graph->mutable_node(j));
        }
      }
      if (pss->collect_partition_graphs) {
        protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
            resp->mutable_metadata()->mutable_partition_graphs();
        for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
          partition_graph_defs->Add()->Swap(
              run_graph_resp->mutable_partition_graph(i));
        }
      }
    }
  }
  return status;
}
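
RunPartitions() is a classic scatter/gather: issue one asynchronous RunGraph call per partition, then block until every callback has fired. Stripped of TensorFlow types, the RunManyGraphs pattern reduces to a counted wait, as in this standalone sketch (all names here are invented for illustration):

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// Minimal stand-in for RunManyGraphs: Wait() returns once WhenDone() has
// been called once per pending call.
class PendingCalls {
 public:
  explicit PendingCalls(int num) : pending_(num) {}

  void WhenDone(int i) {
    std::lock_guard<std::mutex> l(mu_);
    std::cout << "partition " << i << " finished\n";
    if (--pending_ == 0) cv_.notify_all();
  }

  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return pending_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int pending_;
};

// Stand-in for WorkerInterface::RunGraphAsync: run the "partition" on
// another thread and invoke the done-callback when it completes.
void RunGraphAsync(int partition, std::function<void()> done) {
  std::thread([partition, done] {
    // ... execute the partition's subgraph here ...
    done();
  }).detach();
}

int main() {
  const int num_partitions = 3;
  PendingCalls calls(num_partitions);

  // Scatter: one async call per partition.
  for (int i = 0; i < num_partitions; ++i) {
    RunGraphAsync(i, [&calls, i] { calls.WhenDone(i); });
  }

  // Gather: block until every partition has reported back.
  calls.Wait();
  std::cout << "all partitions done\n";
  return 0;
}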

worker.cc

void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  ...
  DoRunGraph(opts, request, response, std::move(done));
  ...
}

...

void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  ...
  WorkerSession* session =
      env_->session_mgr->WorkerSessionForSession(request->session_handle());
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  Status s = PrepareRunGraph(request, &in, out);
  ...
  session->graph_mgr->ExecuteAsync(
      request->graph_handle(), step_id, session, request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector, opts,
       done](Status s) {
        if (s.ok()) {
          s = session->graph_mgr->RecvOutputs(step_id, out);
        }
        ...
        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        ...
      });
}

graph_mgr.cc

void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
                            WorkerSession* session, const ExecutorOpts& opts,
                            StepStatsCollector* collector,
                            MutableRunGraphResponseWrapper* response,
                            CancellationManager* cancellation_manager,
                            const NamedTensors& in, StatusCallback done) {
  // Lookup an item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }
  ...
  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);

  // Sends values specified by the caller.
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }
  ...
  StartParallelExecutors(handle, step_id, item, rendezvous, collector,
                         cost_graph, cancellation_manager,
                         [this, item, rendezvous, done](const Status& s) {
                           done(s);
                           rendezvous->Unref();
                           item->Unref();
                         });
}
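
The rendezvous is the piece doing the heavy lifting here: producers send tensors under string keys, and consumers receive them by key, in either order. A minimal blocking sketch of that contract follows; it is simplified for illustration (TensorFlow's real Rendezvous is asynchronous, ref-counted, and scoped to a step_id, and the key string below is made up):

#include <condition_variable>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <thread>

// Toy rendezvous: a key/value mailbox where Send() and Recv() may arrive
// in either order. String values stand in for tensors.
class ToyRendezvous {
 public:
  void Send(const std::string& key, const std::string& value) {
    std::lock_guard<std::mutex> l(mu_);
    table_[key] = value;
    cv_.notify_all();
  }

  // Blocks until a value for `key` has been sent.
  std::string Recv(const std::string& key) {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [&] { return table_.count(key) > 0; });
    return table_[key];
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::map<std::string, std::string> table_;
};

int main() {
  ToyRendezvous rendezvous;
  // The consumer may start waiting before the producer sends.
  std::thread consumer([&] {
    std::cout << "received: " << rendezvous.Recv("step1/edge_5/x") << "\n";
  });
  rendezvous.Send("step1/edge_5/x", "tensor-bytes");
  consumer.join();
  return 0;
}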
