Skip to content

Commit

Permalink
Added id to the jobAPI swarm_script_executor_cifar10 component deploy (
Browse files Browse the repository at this point in the history
…#2678)

* Added id to the swarm_script_executor_cifar10 component deploy.

* codestyle fix.

* Changed to use job.as_id().

* codestyle fix.

* changed to use job.as_id(shareable_generator) for shareable_generator_id.

* removed the un-necessary job.to() calls.

---------

Co-authored-by: Chester Chen <[email protected]>
Co-authored-by: Sean Yang <[email protected]>
  • Loading branch information
3 people authored Aug 2, 2024
1 parent 5c63229 commit 4b32f27
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions examples/getting_started/pt/swarm_script_executor_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,22 @@
executor = ScriptExecutor(task_script_path=train_script)
job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"])

client_controller = SwarmClientController()
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])

client_controller = CrossSiteEvalClientController()
job.to(client_controller, f"site-{i}", tasks=["cse_*"])

# In swarm learning, each client acts also as an aggregator
aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS)
job.to(aggregator, f"site-{i}")

# In swarm learning, each client uses a model persistor and shareable_generator
job.to(PTFileModelPersistor(model=Net()), f"site-{i}")
job.to(SimpleModelShareableGenerator(), f"site-{i}")
persistor = PTFileModelPersistor(model=Net())
shareable_generator = SimpleModelShareableGenerator()

client_controller = SwarmClientController(
aggregator_id=job.as_id(aggregator),
persistor_id=job.as_id(persistor),
shareable_generator_id=job.as_id(shareable_generator),
)
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])

client_controller = CrossSiteEvalClientController()
job.to(client_controller, f"site-{i}", tasks=["cse_*"])

# job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir")

0 comments on commit 4b32f27

Please sign in to comment.