diff --git a/runpod/runpod-worker-transformers/runpod_infer.py b/runpod/runpod-worker-transformers/runpod_infer.py index a04370e..59216d9 100644 --- a/runpod/runpod-worker-transformers/runpod_infer.py +++ b/runpod/runpod-worker-transformers/runpod_infer.py @@ -139,10 +139,10 @@ def load_quantized(model_name, wbits, groupsize, device): found_safetensors = list(path_to_model.glob("*.safetensors")) pt_path = None - if len(found_pts) == 1: - pt_path = found_pts[0] - elif len(found_safetensors) == 1: - pt_path = found_safetensors[0] + if len(found_pts) > 0: + pt_path = found_pts[-1] + elif len(found_safetensors) > 0: + pt_path = found_safetensors[-1] else: pass @@ -317,8 +317,8 @@ if __name__ == "__main__": path_to_model = next( Path(f'/root/.cache/huggingface/hub/').glob("models--*/snapshots/*/") ) found_pths = list(path_to_model.glob("*.pth")) pt_path = None - if len(found_pths) == 1: - pt_path = found_pts[0] + if len(found_pths) > 0: + pt_path = found_pths[-1] else: print("Could not find the model, exiting...") exit()