Skip to content

Commit

Permalink
add offload (PaddlePaddle#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingshui authored and danleifeng committed Sep 12, 2023
1 parent 32a52bd commit 890c72d
Show file tree
Hide file tree
Showing 7 changed files with 660 additions and 116 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3501,7 +3501,7 @@ void GraphDataGenerator::DoSageForTrain() {
if (total_instance == 0) {
break;
}

ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_[tensor_pair_idx]->ptr());
ins_cursor = ins_buf + ins_buf_pair_len_[tensor_pair_idx] * 2 - total_instance;
auto final_sage_nodes = GenerateSampleGraph(ins_cursor,
Expand Down
22 changes: 21 additions & 1 deletion paddle/fluid/framework/device_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,18 @@ class CPUWorkerBase : public DeviceWorker {
protected:
int thread_id_;
};

class HogwildWorker : public CPUWorkerBase {
struct OffLoadVarInfo {
std::vector<std::string> copy_vars;
std::vector<std::string> backup_vars;
template<typename TCopyer>
void CopyInputs(const Scope* root,
const platform::Place& place,
Scope* scope,
TCopyer *copyer);
template<typename TCopyer>
void BackUpInputs(Scope* root, Scope* scope, TCopyer *copyer);
};
public:
HogwildWorker() {}
virtual ~HogwildWorker() {}
Expand All @@ -287,6 +297,8 @@ class HogwildWorker : public CPUWorkerBase {
// build thread sharding depends
void BuildShardingDepends(const ProgramDesc& program);
int IsParameter(const std::string& name, bool full_match);
bool IsNeedOffload(const std::string &name);
size_t AdjustOffloadOps(const ProgramDesc &program);

std::vector<std::string> op_names_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
Expand All @@ -312,7 +324,15 @@ class HogwildWorker : public CPUWorkerBase {
std::vector<std::string> shard_dump_params_;
std::vector<std::string> shard_dump_fields_;
std::multiset<std::string> free_param_vars_;
bool is_multi_node_ = false;
bool sharding_mode_ = false;
bool enable_adjust_op_order_ = false;
// offload vars
bool is_offload_communication_ = false;
bool is_offload_param_ = false;
std::vector<std::string> offload_exts_;
std::multiset<std::string> offload_names_;
std::unordered_map<const OperatorBase*, OffLoadVarInfo> offload_vars_;
};

class DownpourWorker : public HogwildWorker {
Expand Down
Loading

0 comments on commit 890c72d

Please sign in to comment.