LightningCLI doesn't parse callback args correctly if more than one args for each callback #15007
@junwang-wish is it possible for you to post a minimal python script that reproduces this?
@mauvilsa thx, here u go, say this is named `tmp.py`:

```python
from main_utils import LLMData
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import trainer
from pytorch_lightning.callbacks import BasePredictionWriter


class PredictionWriter_Conditional_LM(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)


class LLM_Inference_Base(pl.LightningModule):
    def __init__(self, llm_type: str, ckpt_path: str, config_path: str,
                 output_dir: str, write_interval: str, **kwargs):
        super().__init__()
        self.save_hyperparameters()


def cli_main():
    cli = LightningCLI()


if __name__ == "__main__":
    cli_main()
```

If I run

```shell
python tmp.py predict \
  --model=LLM_Inference_Base \
  --model.llm_type="conditional_lm" \
  --model.ckpt_path="models/model/version_1/epoch=0-step=100.ckpt" \
  --model.output_dir="models/model/version_1/config.yaml" \
  --model.write_interval="batch" \
  --data=LLMData \
  --data.data_source_yaml_path="datasets/data/data.yaml" \
  --data.model_name="t5-base" \
  --trainer.callbacks+=PredictionWriter_Conditional_LM \
  --trainer.callbacks.write_interval="batch" \
  --trainer.callbacks.output_dir="models/model/version_1" \
  --print_config
```

you would get

```yaml
# pytorch_lightning==1.7.7
seed_everything: true
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
  - class_path: __main__.PredictionWriter_Conditional_LM
    init_args:
      output_dir: models/model/version_1
      write_interval: null
  default_root_dir: null
  gradient_clip_val: null
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: null
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: null
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 50
  accelerator: null
  strategy: null
  sync_batchnorm: false
  precision: 32
  enable_model_summary: true
  weights_save_path: null
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
return_predictions: null
ckpt_path: null
model:
  class_path: __main__.LLM_Inference_Base
  init_args:
    llm_type: conditional_lm
    ckpt_path: models/model/version_1/epoch=0-step=100.ckpt
    config_path: null
    output_dir: models/model/version_1/config.yaml
    write_interval: batch
data:
  class_path: main_utils.LLMData
  init_args:
    data_source_yaml_path: datasets/data/data.yaml
    model_name: t5-base
    raw_cache_dir: /data/junwang/.cache/general
    batch_size: 16
    overwrite_cache: false
    max_length: 250
    predict_on_test: true
    num_workers: 80
    max_length_out: 100
    cache_dir: null
    force_download: false
    resume_download: false
    proxies: null
    use_auth_token: null
    local_files_only: false
    revision: null
    trust_remote_code: null
    subfolder: ''
```

Notice that `write_interval` under the callback's `init_args` ends up `null` even though `--trainer.callbacks.write_interval="batch"` was passed.
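To inspect the dropped value programmatically, one could load an excerpt of the printed config with PyYAML (a minimal sketch, assuming PyYAML is available; the class and paths come from the report above):

```python
import yaml

# Excerpt of the --print_config output shown above
printed_config = """
callbacks:
- class_path: __main__.PredictionWriter_Conditional_LM
  init_args:
    output_dir: models/model/version_1
    write_interval: null
"""

cfg = yaml.safe_load(printed_config)
init_args = cfg["callbacks"][0]["init_args"]
print(init_args["output_dir"])      # models/model/version_1
print(init_args["write_interval"])  # None: the 'batch' value was dropped
```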
Great, thank you!
@junwang-wish thank you very much for reporting. This was a bug in jsonargparse, fixed in commit 3337a0e and just released as version 4.15.1. Please update the package (e.g. `pip install -U jsonargparse`).
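To confirm that the installed jsonargparse includes the fix (which shipped in 4.15.1), a quick stdlib-only check could look like this:

```python
# Print the installed jsonargparse version; the fix shipped in 4.15.1.
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("jsonargparse"))
except PackageNotFoundError:
    print("jsonargparse is not installed")
```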
Bug description
I'm following the PyTorch Lightning 1.7.7 docs to specify LightningCLI args. Ideally, both `trainer.callbacks.output_dir` and `trainer.callbacks.write_interval` get passed in to instantiate `PredictionWriter_Conditional_LM`. However, only the last argument used, `--trainer.callbacks.output_dir="models/model/version_1"`, gets passed in, and not `--trainer.callbacks.write_interval="batch"`.
How to reproduce the bug
No response
Error messages and logs
Environment
More info
No response