# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This sentinel is a reminder to choose a real run name.
# If there is already a checkpoint under this run, that checkpoint will auto-resume.
run_name: ""
model_name: "default" # Override config settings to match a specific model. Other than the override, nothing should use this!
normalization_layer_epsilon: 1.e-05
################################## CHECKPOINTING ##################################
# Checkpointing makes the following choices in the following order, starting with (1):
# (1) If there is already a checkpoint for this run_name, we load the latest entire checkpoint.
#     This ensures that if we're resuming a run after preemption or hardware failure, we lose a minimum of state.
# (2) Same priority and mutually exclusive -- you can't set both:
#     * If load_parameters_path is set, we load a parameter-only checkpoint from that path.
#     * If load_full_state_path is set, we load a full state checkpoint from that path.
# (3) Otherwise, we don't load a checkpoint and initialize state from scratch.
# Loads just the parameters from a specific directory
# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/default/NUMBER or NUMBER/default
load_parameters_path: ""
# Loads a full checkpoint including optimizer state and step count from a specific directory
# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/default/NUMBER or NUMBER/default
load_full_state_path: ""
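# Example (hypothetical bucket, run name, and step number): to fine-tune from the
# weights of an earlier run while starting with fresh optimizer state and step count,
# set only the parameter-only path, e.g.:
# load_parameters_path: "gs://my-base-output-directory/my-previous-run-name/checkpoints/default/10000/default"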
# If enable_checkpointing is true, an asynchronous checkpointer will be used if
# async_checkpointing is true, else a synchronous one is used. If you have
# problems with the checkpointer we recommend trying the synchronous one.
enable_checkpointing: True
async_checkpointing: True
checkpoint_period: 10_000
force_unroll: False # whether to unroll the loop during generate_param_only_checkpoint
############################### END CHECKPOINTING ##################################
reuse_example_batch: 0 # for testing TPU performance, this option repeatedly uses the same batch.
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true, save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
gcs_metrics: False
# If true, save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
# Activation dtypes.
dtype: "bfloat16"
quantization: ""
decoder_block: "llama2" # which style of DecoderBlock to use.
# Global parameter scale needs to be a power of 2. If you want finer-grained control of the model sizes
# then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads,
# base_mlp_dim, base_num_decoder_layers and/or head_dim.
weight_dtype: float32
global_parameter_scale: 1
base_emb_dim: 2048
base_num_query_heads: 16
base_num_kv_heads: 16
base_mlp_dim: 7168
base_num_decoder_layers: 16
head_dim: 128
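# Rough parameter count with these defaults (a back-of-the-envelope sketch, ignoring
# norms and biases, assuming the gated MLP implied by the two mlp_activations below):
#   attention per layer: 4 * (2048 * 16 * 128) ~ 16.8M (q, k, v, out projections)
#   mlp per layer:       3 * (2048 * 7168)     ~ 44.0M (two input projections, one output)
#   16 layers ~ 0.97B, plus two 32_000 x 2048 embedding tables (~65.5M each, since
#   logits_via_embedding is False below), i.e. roughly 1.1B parameters in total.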
# mixture of experts (moe)
num_experts: 1
num_experts_per_tok: 1
mlp_activations: ["silu", "linear"]
dropout_rate: 0
logits_via_embedding: False
normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability
# proj, minimal, full, or none
remat_policy: 'full'
scan_layers: True
param_scan_axis: 1
attention: 'flash' # Supported attention: dot_product, flash, gpu_flash_xla, gpu_flash_triton
# Combine matmuls for QKV and MLP
fused_qkv: False
fused_mlp: False
record_internal_nn_metrics: 0
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
# Jax cache directory
jax_cache_dir: "~/jax_cache"
# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' and 'cpu'
# Parallelism
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
['activation_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
['activation_mlp', 'tensor'],
['activation_kv', 'tensor'],
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['heads', ['tensor', 'autoregressive']],
['kv', []],
['cache_batch', []],
['cache_heads', ['autoregressive']],
['cache_kv', []],
['cache_sequence', []],
]
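# Note (an assumption about rule matching): rules are consulted in order, so repeated
# logical axis names (e.g. 'activation_vocab', 'embed') act as ordered fallbacks when
# an earlier rule's mesh axes are already in use for that array.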
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, the product of the DCN axes should equal the number of slices,
# and the product of the ICI axes should equal the number of devices per slice
# (see the worked example below).
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_tensor_parallelism: 1 # never recommended
dcn_autoregressive_parallelism: 1 # never recommended
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_autoregressive_parallelism: 1
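# Worked example (hypothetical topology): on 2 slices of v5e-256 (256 chips per
# slice), data parallelism across slices with FSDP within each slice would be:
#   dcn_data_parallelism: 2     (or -1 to auto-shard; DCN product = number of slices)
#   ici_fsdp_parallelism: 256   (or -1 to auto-shard; ICI product = chips per slice)
# with all other dcn_* and ici_* values left at 1.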
# Dataset
# Replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
dataset_path: ""
vocab_size: 32_000 # powers of 2 for sharding
tokenizer_path: "assets/tokenizer.llama2"
dataset_name: 'c4/en:3.0.1'
eval_dataset_name: 'c4/en:3.0.1'
eval_split: 'validation'
per_device_batch_size: 12.0
eval_per_device_batch_size: 0
max_corpus_chars: 10_000_000
dataset_type: c4 # must be c4 or synthetic
# Setting for grain
grain_worker_count: 4
# Training loop
steps: 150_001 # If set to -1, inherits the value from learning_rate_schedule_steps
log_period: 100 # period, in steps, at which TensorBoard metrics are flushed
# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
# Learning rate schedule has either two or three parts:
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
# The zero learning rate section can be used to more accurately measure the fully trained model's performance.
learning_rate: 3.e-5
cosine_learning_rate_final_fraction: 0.1
warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
# However, you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case training will end before
# the learning rate has fully decayed. Or you may choose a shorter schedule, where the remaining steps run with a learning rate of 0.
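# Worked example with the defaults above (steps = 150_001, so the schedule also spans 150_001 steps):
#   warmup: steps 0 to ~15_000 (150_001 * 0.1), LR rises linearly from 0 to 3.e-5
#   cosine decay: steps ~15_000 to 150_001, LR falls from 3.e-5 to 3.e-6 (3.e-5 * 0.1)
#   zero-LR tail: empty, since the schedule length equals the number of steps.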
max_target_length: 2048 # Maximum sequence length
max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression
prompt: "I love to" # Prompt for language model sampling.
load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory
prefill_cache_dir: "" # If set, decode.py writes the prefill cache to this directory; if load_from_prefill_dir is also true, decode.py reads the cache from it
autoregressive_decode_assert: ""
enable_profiler: False
# If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane result from the first host.
upload_all_profiler_results: False
# Skip first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 1
# Profile for a small number of steps to avoid a large profile file size.
profiler_steps: 5
# When enable_dropout is false the model is a deterministic function of the
# data_shuffle_seed and init_weights_seed (i.e. reproducible losses)
enable_dropout: True
enable_data_shuffling: True
data_shuffle_seed: 0
init_weights_seed: 0
# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0
# AdamW optimizer parameters
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
opt_type: "adamw" # one of "adam_pax" or "adamw"
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_eps_root: 0. # A small constant applied to denominator inside the square root.
adam_weight_decay: 0.1 # AdamW Weight decay
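# For reference, these knobs map onto the standard AdamW update (a sketch; g is the
# gradient, theta the parameters, m_hat/v_hat the bias-corrected moments):
#   m_t = adam_b1 * m_{t-1} + (1 - adam_b1) * g
#   v_t = adam_b2 * v_{t-1} + (1 - adam_b2) * g^2
#   theta -= learning_rate * (m_hat / (sqrt(v_hat + adam_eps_root) + adam_eps)
#                             + adam_weight_decay * theta)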
# Stack trace parameters
collect_stack_trace: False
stack_trace_to_cloud: False # Uploads to Cloud Logging if True, else prints to the console.
stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds.
# Use iota operator in Embed
use_iota_embed: False
# use positional embedding
use_untrainable_positional_embedding: False
trainable_position_size: -1 # enable GPT-3-style trainable position embeddings with a positive trainable_position_size
# Ahead of time Compilation (aka AOT)
# Only set these arguments if you are running train_compile or loading a compiled train step.
compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle
compile_topology: '' # Target hardware version, e.g. 'v5e-256'
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
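# Example (values from the comments above): to AOT-compile a train step for two
# v5e-256 slices and save the result for later reuse, one could set:
#   compiled_trainstep_file: "compiled_train_v5e-256.pickle"
#   compile_topology: 'v5e-256'
#   compile_topology_num_slices: 2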
decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, or topk
decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
decode_sampling_top_k: 0 # set if you're doing top-k
decode_sampling_temperature: 1.
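# Example (hypothetical values): nucleus (top-p) sampling keeps the smallest set
# of tokens whose cumulative probability exceeds p, e.g.:
#   decode_sampling_strategy: "nucleus"
#   decode_sampling_nucleus_p: 0.9
#   decode_sampling_temperature: 0.7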
eval_interval: -1 # number of training steps between evaluations
target_eval_loss: 0. # stop training early once the eval loss reaches this target