From bd2c0aa0430cac618f92f6de529dc3114d54250e Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Wed, 12 Apr 2023 17:43:18 +0200 Subject: [PATCH] more robust model file search --- runpod/runpod-worker-transformers/runpod_infer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/runpod/runpod-worker-transformers/runpod_infer.py b/runpod/runpod-worker-transformers/runpod_infer.py index a04370e..59216d9 100644 --- a/runpod/runpod-worker-transformers/runpod_infer.py +++ b/runpod/runpod-worker-transformers/runpod_infer.py @@ -139,10 +139,10 @@ def load_quantized(model_name, wbits, groupsize, device): found_safetensors = list(path_to_model.glob("*.safetensors")) pt_path = None - if len(found_pts) == 1: - pt_path = found_pts[0] - elif len(found_safetensors) == 1: - pt_path = found_safetensors[0] + if len(found_pts) > 0: + pt_path = found_pts[-1] + elif len(found_safetensors) > 0: + pt_path = found_safetensors[-1] else: pass @@ -317,8 +317,8 @@ if __name__ == "__main__": path_to_model = next( Path(f'/root/.cache/huggingface/hub/').glob("models--*/snapshots/*/") ) found_pths = list(path_to_model.glob("*.pth")) pt_path = None - if len(found_pths) == 1: - pt_path = found_pts[0] + if len(found_pths) > 0: + pt_path = found_pths[-1] else: print("Could not find the model, exiting...") exit()