diff --git a/.gitignore b/.gitignore index 6b6db06fb..267995984 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,13 @@ +## HIP-compiled kernels etc. +*hip* +# +local_examples/ +logs/ +trash/ +kb-runs-gpt/ +ds_configs/ +gpt2-tokenizer/ +smi-output/ # tests # megatron autogenerated indices tests/data/*/*npy diff --git a/examples/run_evalharness_deepspeed.md b/examples/run_evalharness_deepspeed.md index 695d9d0aa..60f380d9c 100644 --- a/examples/run_evalharness_deepspeed.md +++ b/examples/run_evalharness_deepspeed.md @@ -15,6 +15,7 @@ Get lm-eval harness (https://github.com/EleutherAI/lm-evaluation-harness) and `b start-prod pip install best-download==0.0.7 pip install git+https://github.com/EleutherAI/lm-evaluation-harness +pip install --upgrade scipy ``` 2. Pre-download needed datasets diff --git a/examples/run_evalharness_lumi.sh b/examples/run_evalharness_lumi.sh new file mode 100644 index 000000000..1721d91d2 --- /dev/null +++ b/examples/run_evalharness_lumi.sh @@ -0,0 +1,113 @@ +#!/bin/bash +#SBATCH --exclude=nid005159 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=256G +#SBATCH -p eap +#SBATCH -t 2-0:00:00 +#SBATCH --gpus-per-node=mi250:1 +#SBATCH --exclusive=user +#SBATCH --hint=nomultithread +#SBATCH --account=project_462000119 +#SBATCH -o logs/%j.out +#SBATCH -e logs/%j.err + +# if run without sbatch, invoke here +if [ -z $SLURM_JOB_ID ]; then + mkdir -p logs + sbatch "$0" + exit +fi + +set -euo pipefail + +# symlink logs/latest_eval.out and logs/latest_eval.err +ln -f -s $SLURM_JOB_ID.out logs/latest_eval.out +ln -f -s $SLURM_JOB_ID.err logs/latest_eval.err + +# Data +CHECKPOINT_PATH=/scratch/project_462000119/muennighoff/nov-2022-optimization/checkpoints/global_step10 +VARIANT=global_step10 + +export HF_DATASETS_OFFLINE=1 +export HF_DATASETS_CACHE=/scratch/project_462000119/ds_cache + +VOCAB_FILE="gpt2/vocab.json" +MERGE_FILE="gpt2/merges.txt" + +PP_SIZE=1 +TP_SIZE=1 +# different from the training MICRO_BATCH_SIZE - no optim memory, so can do bigger BS +# make as big as it can fit into gpu w/o OOM, but not too close to 100% +EVAL_MICRO_BATCH_SIZE=1 +MICRO_BS_MULTIPLIER=1 + +# Model parameters +SEQ_LEN=2048 + +# Dummy arguments +MEGATRON_REQUIRED_ARGS=" \ + --num-layers -1 \ + --hidden-size -1 \ + --num-attention-heads -1 \ + --seq-length -1 \ + --max-position-embeddings -1 \ +" + +ZERO_STAGE=0 + +mkdir -p ds_configs +DS_CONFIG_PATH="ds_configs/$SLURM_JOB_ID.json" + +cat < $DS_CONFIG_PATH +{ + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": 1, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": $ZERO_STAGE + }, + "bf16": { + "enabled": true + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} +EOF + +DEEPSPEED_ARGS=" \ + --deepspeed \ + --deepspeed_config $DS_CONFIG_PATH \ + --zero-stage $ZERO_STAGE \ + " + +CMD="Megatron-DeepSpeed/tasks/eval_harness/evaluate.py \ + --load $CHECKPOINT_PATH \ + --results_path $VARIANT-results.json \ + --tensor-model-parallel-size $TP_SIZE \ + --pipeline-model-parallel-size $PP_SIZE \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --micro-batch-size $EVAL_MICRO_BATCH_SIZE \ + --no-load-optim \ + --no-load-rng \ + --bf16 \ + --inference \ + --seq-length $SEQ_LEN \ + --task_list copa,piqa,rte,winogrande,hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions \ + --intermed_results \ + --adaptive_seq_len \ + --micro_bs_multiplier $MICRO_BS_MULTIPLIER \ + $MEGATRON_REQUIRED_ARGS \ + $DEEPSPEED_ARGS \ + " + +echo $CMD + +echo "START $SLURM_JOBID: $(date)" + +srun --label launch.sh $CMD + +echo "END $SLURM_JOBID: $(date)" + diff --git a/megatron/arguments.py b/megatron/arguments.py index c18235a78..d2499d149 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -549,6 +549,12 @@ def _add_training_args(parser): group.add_argument('--no-bias-dropout-fusion', action='store_false', help='Disable bias and dropout fusion.', dest='bias_dropout_fusion') + group.add_argument('--no-layer-norm-fusion', action='store_false', + help='Disable fused layer norm.', + dest='layer_norm_fusion') + group.add_argument('--no-optimizer-fusion', action='store_false', + help='Disable FusedAdam/FusedSGD norm.', + dest='optimizer_fusion') group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='Optimizer function') diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py index e2ac2567b..bdc654c39 100644 --- a/megatron/fused_kernels/__init__.py +++ b/megatron/fused_kernels/__init__.py @@ -17,81 +17,97 @@ import pathlib import subprocess +import torch from torch.utils import cpp_extension +# Setting this param to a list has a problem of generating different +# compilation commands (with diferent order of architectures) and +# leading to recompilation of fused kernels. Set it to empty string +# to avoid recompilation and assign arch flags explicity in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + def load(args): - # Setting this param to a list has a problem of generating different - # compilation commands (with diferent order of architectures) and - # leading to recompilation of fused kernels. Set it to empty string - # to avoid recompilation and assign arch flags explicity in - # extra_cuda_cflags below - # - # but if a user wants to set an explicit list of archs to compile to, then let that list - # through: - arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None) - if arch_list is None: - os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - # # Check if cuda 11 is installed for compute capability 8.0 - # cc_flag = [] - # _, bare_metal_major, _ = _get_cuda_bare_metal_version( - # cpp_extension.CUDA_HOME) - # if int(bare_metal_major) >= 11: - # cc_flag.append('-gencode') - # cc_flag.append('arch=compute_80,code=sm_80') + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + if torch.version.hip is None: + _, bare_metal_major, _ = _get_cuda_bare_metal_version( + cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') # Build path srcpath = pathlib.Path(__file__).parent.absolute() buildpath = srcpath / 'build' - buildpath.mkdir(parents=True, exist_ok=True) + _create_build_dir(buildpath) # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + def _cpp_extention_load_helper(name, sources, extra_cuda_flags, extra_include_paths): + if torch.version.hip is not None: + extra_cuda_cflags=['-O3'] + extra_cuda_flags + cc_flag + else: + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag + return cpp_extension.load( name=name, sources=sources, build_directory=buildpath, extra_cflags=['-O3',], - extra_cuda_cflags=['-O3', - '--use_fast_math'] + extra_cuda_flags, + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=extra_include_paths, verbose=(args.rank == 0) ) - # '-gencode', 'arch=compute_70,code=sm_70', # ============== # Fused softmax. # ============== + if torch.version.hip is not None: + extra_include_paths=[os.path.abspath(srcpath)] + else: + extra_include_paths=[] + if args.masked_softmax_fusion: - extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda'] + if torch.version.hip is not None: + extra_cuda_flags = ['-D__HIP_NO_HALF_OPERATORS__=1', + '-D__HIP_NO_HALF_CONVERSIONS__=1'] + else: + extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] # Upper triangular softmax. sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( "scaled_upper_triang_masked_softmax_cuda", - sources, extra_cuda_flags) + sources, extra_cuda_flags, extra_include_paths) # Masked softmax. sources=[srcpath / 'scaled_masked_softmax.cpp', srcpath / 'scaled_masked_softmax_cuda.cu'] scaled_masked_softmax_cuda = _cpp_extention_load_helper( - "scaled_masked_softmax_cuda", sources, extra_cuda_flags) + "scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths) # ================================= # Mixed precision fused layer norm. # ================================= - extra_cuda_flags = ['-maxrregcount=50'] + if torch.version.hip is not None: + extra_cuda_flags = [] + else: + extra_cuda_flags = ['-maxrregcount=50'] + sources=[srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'] fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper( - "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags) + "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags, extra_include_paths) def _get_cuda_bare_metal_version(cuda_dir): diff --git a/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/megatron/fused_kernels/layer_norm_cuda_kernel.cu index 28a579e1a..aae0c993c 100644 --- a/megatron/fused_kernels/layer_norm_cuda_kernel.cu +++ b/megatron/fused_kernels/layer_norm_cuda_kernel.cu @@ -76,7 +76,8 @@ void cuWelfordMuSigma2( const int i1, U& mu, U& sigma2, - U* buf) + U* buf, + const int GPU_WARP_SIZE) { // Assumptions: // 1) blockDim.x == warpSize @@ -106,12 +107,11 @@ void cuWelfordMuSigma2( cuWelfordOnlineSum(curr,mu,sigma2,count); } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { + U sigma2B = WARP_SHFL_DOWN(sigma2, stride); + U muB = WARP_SHFL_DOWN(mu, stride); + U countB = WARP_SHFL_DOWN(count, stride); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -160,7 +160,8 @@ void cuWelfordMuSigma2( const int i1, float& mu, float& sigma2, - float* buf) + float* buf, + const int GPU_WARP_SIZE) { // Assumptions: // 1) blockDim.x == warpSize @@ -201,12 +202,11 @@ void cuWelfordMuSigma2( cuWelfordOnlineSum(curr,mu,sigma2,count); } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x+(1< 0; stride /= 2) { + float sigma2B = WARP_SHFL_DOWN(sigma2, stride); + float muB = WARP_SHFL_DOWN(mu, stride); + float countB = WARP_SHFL_DOWN(count, stride); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -246,14 +246,25 @@ void cuWelfordMuSigma2( } } } - +#ifndef __HIP_PLATFORM_HCC__ template U rsqrt(U v) { +#else +template __device__ U rsqrt(U v) { +#endif return U(1) / sqrt(v); } +#ifndef __HIP_PLATFORM_HCC__ template<> float rsqrt(float v) { +#else +template<> __device__ float rsqrt(float v) { +#endif return rsqrtf(v); } +#ifndef __HIP_PLATFORM_HCC__ template<> double rsqrt(double v) { +#else +template<> __device__ double rsqrt(double v) { +#endif return rsqrt(v); } @@ -297,18 +308,23 @@ void cuApplyLayerNorm( const int n2, const U epsilon, const V* __restrict__ gamma, - const V* __restrict__ beta + const V* __restrict__ beta, + const int GPU_WARP_SIZE ) { // Assumptions: // 1) blockDim.x == warpSize // 2) Tensors are contiguous // +#ifndef __HIP_PLATFORM_HCC__ for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#else + for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#endif SharedMemory shared; U* buf = shared.getPointer(); U mu,sigma2; - cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,GPU_WARP_SIZE); const T* lvals = vals + i1*n2; V* ovals = output_vals + i1*n2; U c_invvar = rsqrt(sigma2 + epsilon); @@ -543,7 +559,11 @@ void cuComputeGradInput( const V* gamma, T* grad_input) { +#ifndef __HIP_PLATFORM_HCC__ for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#else + for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#endif U sum_loss1 = U(0); U sum_loss2 = U(0); const U c_mean = mean[i1]; @@ -667,7 +687,11 @@ void HostApplyLayerNorm( ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); - const dim3 threads(32,4,1); + const int warp_size = at::cuda::warp_size(); + dim3 threads(warp_size,4,1); +#ifndef __HIP_PLATFORM_HCC__ + threads.y = 1; +#endif const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); @@ -682,7 +706,9 @@ void HostApplyLayerNorm( input, n1,n2, U(epsilon), - gamma,beta); + gamma, + beta, + warp_size); } @@ -735,11 +761,16 @@ void HostLayerNormGradient( ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); + const int warp_size = at::cuda::warp_size(); if (gamma != NULL && beta != NULL) { // compute grad_gamma(j) and grad_beta(j) +#ifndef __HIP_PLATFORM_HCC__ + const int part_size = warp_size; +#else const int part_size = 16; - const dim3 threads2(32,4,1); +#endif + const dim3 threads2(warp_size,4,1); const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); @@ -758,7 +789,7 @@ void HostLayerNormGradient( part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR()); - const dim3 threads3(32,8,1); + const dim3 threads3(warp_size,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); const int nshared3 = threads3.x * threads3.y * sizeof(U); cuComputeGradGammaBeta<<>>( @@ -774,7 +805,10 @@ void HostLayerNormGradient( const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - const dim3 threads1(32,4,1); + dim3 threads1(warp_size,4,1); +#ifndef __HIP_PLATFORM_HCC__ + threads1.y = 2; +#endif int nshared = threads1.y > 1 ? threads1.y*threads1.x*sizeof(U) : diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu index 2efee39a6..0c068c7cb 100644 --- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -18,7 +18,9 @@ #include #include #include +#ifndef __HIP_PLATFORM_HCC__ #include +#endif #include #include #include "scaled_masked_softmax.h" diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h index 6df83fc10..ee140c037 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -17,7 +17,8 @@ #pragma once #include -#include +#include +// #include #include #include #include diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu index 5efc3d412..59e452584 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -18,7 +18,9 @@ #include #include #include +#ifndef __HIP_PLATFORM_HCC__ #include +#endif #include #include #include "scaled_upper_triang_masked_softmax.h" diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 55e9c9dd8..7b1d7eaa7 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -86,10 +86,13 @@ def __init__(self, normalized_shape, eps=1e-5): args = get_args() self.layernorm_tp_auto_sync = args.sync_tp_duplicated_parameters - self.use_meg_ds_fused_layer_norm = ( - args.bf16 # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm - or version.parse(torch.__version__) >= version.parse("1.11.0") # https://github.com/pytorch/pytorch/pull/66920 - ) + if not args.layer_norm_fusion: + self.use_meg_ds_fused_layer_norm = False + else: + self.use_meg_ds_fused_layer_norm = ( + args.bf16 # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm + or version.parse(torch.__version__) >= version.parse("1.11.0") # https://github.com/pytorch/pytorch/pull/66920 + ) def reset_parameters(self): diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 738717d55..b492abfb4 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from apex.optimizers import FusedAdam as Adam -from apex.optimizers import FusedSGD as SGD +from torch.optim import AdamW +from torch.optim import SGD +from apex.optimizers import FusedAdam +from apex.optimizers import FusedSGD from megatron import get_args from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm @@ -72,18 +74,24 @@ def get_megatron_optimizer(model): if args.use_bnb_optimizer: import bitsandbytes as bnb adam_optimizer = bnb.optim.Adam8bit + elif args.optimizer_fusion: + adam_optimizer = FusedAdam else: - adam_optimizer = Adam + adam_optimizer = AdamW optimizer = adam_optimizer(param_groups, lr=args.lr, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps) elif args.optimizer == 'sgd': - optimizer = SGD(param_groups, - lr=args.lr, - weight_decay=args.weight_decay, - momentum=args.sgd_momentum) + if args.optimizer_fusion: + sgd_optimizer = FusedSGD + else: + sgd_optimizer = SGD + optimizer = sgd_optimizer(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + momentum=args.sgd_momentum) else: raise Exception('{} optimizer is not supported.'.format( args.optimizer)) diff --git a/tasks/eval_harness/download.py b/tasks/eval_harness/download.py index d2abcd83a..be5d5c303 100644 --- a/tasks/eval_harness/download.py +++ b/tasks/eval_harness/download.py @@ -1,6 +1,7 @@ # Downloads the specified taks in the evaluation harness # This is particularly useful when running in environments where the GPU nodes # do not have internet access. This way we can pre-download them and use the cached data-set during evaluation. +# May want to set a cache before, e.g. export HF_DATASETS_CACHE=/scratch/project_462000119/ds_cache from lm_eval import tasks from lm_eval.tasks import ALL_TASKS diff --git a/tasks/eval_harness/evaluate.py b/tasks/eval_harness/evaluate.py index 68dd649fd..7e0cdb60e 100644 --- a/tasks/eval_harness/evaluate.py +++ b/tasks/eval_harness/evaluate.py @@ -260,7 +260,7 @@ def tokenizer_encode(self, text): from megatron.initialize import initialize_megatron import megatron -from tools.convert_checkpoint.deepspeed_checkpoint import DeepSpeedCheckpoint +from deepspeed.checkpoint.deepspeed_checkpoint import DeepSpeedCheckpoint from tools.convert_checkpoint.deepspeed_to_megatron import _create_rank_checkpoint def override_args(args, override_args, skip_keys, skip_if_specified_keys): diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 74e5ca7c9..08471d0bb 100755 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -4,7 +4,7 @@ import os import torch from collections import OrderedDict -from .deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint +from deepspeed.checkpoint.deepspeed_checkpoint import DeepSpeedCheckpoint MODEL_KEY = 'model' ARGS_KEY = 'args'