
more robust model file search

Hendrik Langer, 2 years ago
commit bd2c0aa043 on master
1 changed file, 12 changed lines (6 additions, 6 deletions):
    runpod/runpod-worker-transformers/runpod_infer.py

@@ -139,10 +139,10 @@ def load_quantized(model_name, wbits, groupsize, device):
     found_safetensors = list(path_to_model.glob("*.safetensors"))
     pt_path = None
-    if len(found_pts) == 1:
-        pt_path = found_pts[0]
-    elif len(found_safetensors) == 1:
-        pt_path = found_safetensors[0]
+    if len(found_pts) > 0:
+        pt_path = found_pts[-1]
+    elif len(found_safetensors) > 0:
+        pt_path = found_safetensors[-1]
     else:
         pass
@@ -317,8 +317,8 @@ if __name__ == "__main__":
     path_to_model = next( Path(f'/root/.cache/huggingface/hub/').glob("models--*/snapshots/*/") )
     found_pths = list(path_to_model.glob("*.pth"))
     pt_path = None
-    if len(found_pths) == 1:
-        pt_path = found_pts[0]
+    if len(found_pths) > 0:
+        pt_path = found_pths[-1]
     else:
         print("Could not find the model, exiting...")
         exit()
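
Both hunks relax the search from requiring exactly one matching checkpoint (== 1, index 0) to accepting any non-empty match list (> 0, index -1); the second hunk also corrects the variable name found_pts to found_pths. Below is a minimal standalone sketch of the resulting search logic; the helper name find_model_file and the Path argument are illustrative only, since the real code inlines this logic inside load_quantized and the __main__ block.

    from pathlib import Path

    def find_model_file(path_to_model: Path):
        # Gather all candidate checkpoints; glob() order is filesystem-dependent,
        # so [-1] simply picks one of the matches (the last one returned).
        found_pts = list(path_to_model.glob("*.pt"))
        found_safetensors = list(path_to_model.glob("*.safetensors"))
        pt_path = None
        if len(found_pts) > 0:            # any number of .pt files is now accepted
            pt_path = found_pts[-1]
        elif len(found_safetensors) > 0:  # otherwise fall back to .safetensors
            pt_path = found_safetensors[-1]
        return pt_path

For example, a snapshot directory containing two .pt files made the old == 1 check match nothing, so no path was found; with > 0 the search still returns a usable path.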
