From d78181974ba1f9b34505ad1d5002e5199854cf91 Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Wed, 5 Apr 2023 20:05:40 +0200 Subject: [PATCH] chatbot remote worker test --- runpod/runpod-worker-transformers/Dockerfile | 7 +++++-- .../runpod-worker-transformers/model_fetcher.py | 15 ++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/runpod/runpod-worker-transformers/Dockerfile b/runpod/runpod-worker-transformers/Dockerfile index c352b9f..3b8369a 100644 --- a/runpod/runpod-worker-transformers/Dockerfile +++ b/runpod/runpod-worker-transformers/Dockerfile @@ -1,6 +1,9 @@ ARG BASE_IMAGE=nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04 FROM ${BASE_IMAGE} as dev-base +ARG MODEL_NAME +ENV MODEL_NAME=${MODEL_NAME} + WORKDIR / SHELL ["/bin/bash", "-o", "pipefail", "-c"] ENV DEBIAN_FRONTEND noninteractive\ @@ -44,11 +47,11 @@ RUN mkdir /workspace WORKDIR /workspace COPY model_fetcher.py /workspace/ -RUN python model_fetcher.py --model_name=${MODEL_NAME} +RUN python3 model_fetcher.py --model_name=${MODEL_NAME} #RUN git lfs install && \ # git clone --depth 1 https://huggingface.co/${MODEL_NAME} COPY runpod_infer.py /workspace/ COPY test_input.json /workspace/ -CMD python -u runpod_infer.py --model_name=${MODEL_NAME} +CMD python3 -u runpod_infer.py --model_name=${MODEL_NAME} diff --git a/runpod/runpod-worker-transformers/model_fetcher.py b/runpod/runpod-worker-transformers/model_fetcher.py index 3758a9e..62c502e 100644 --- a/runpod/runpod-worker-transformers/model_fetcher.py +++ b/runpod/runpod-worker-transformers/model_fetcher.py @@ -7,7 +7,7 @@ import argparse import torch from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM, GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM) - +from huggingface_hub import snapshot_download def download_model(model_name): @@ -28,8 +28,9 @@ def download_model(model_name): # --------------------------------- Pygmalion -------------------------------- # elif model_name == 'pygmalion-6b': - AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b") - AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b") +# AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b", load_in_8bit=True) +# AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b") + snapshot_download(repo_id="PygmalionAI/pygmalion-6b", revision="main") # ----------------------------------- GPT-J ----------------------------------- # elif model_name == 'gpt-j-6b': @@ -39,17 +40,17 @@ def download_model(model_name): # ------------------------------ PPO Shygmalion 6B ----------------------------- # elif model_name == 'ppo-shygmalion-6b': - AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b") + AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b", load_in_8bit=True) AutoTokenizer.from_pretrained("TehVenom/PPO_Shygmalion-6b") # ------------------------------ Dolly Shygmalion 6B ----------------------------- # elif model_name == 'dolly-shygmalion-6b': - AutoModelForCausalLM.from_pretrained("TehVenom/Dolly_Shygmalion-6b") + AutoModelForCausalLM.from_pretrained("TehVenom/Dolly_Shygmalion-6b", load_in_8bit=True) AutoTokenizer.from_pretrained("TehVenom/Dolly_Shygmalion-6b") # ------------------------------ Erebus 13B (NSFW) ----------------------------- # elif model_name == 'erebus-13b': - AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus") + AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus", load_in_8bit=True) AutoTokenizer.from_pretrained("KoboldAI/OPT-13B-Erebus") # --------------------------- Alpaca 13B (Quantized) -------------------------- # @@ -59,7 +60,7 @@ def download_model(model_name): # --------------------------------- Alpaca 13B -------------------------------- # elif model_name == 'gpt4-x-alpaca': - AutoModelForCausalLM.from_pretrained("chavinlo/gpt4-x-alpaca") + AutoModelForCausalLM.from_pretrained("chavinlo/gpt4-x-alpaca", load_in_8bit=True) AutoTokenizer.from_pretrained("chavinlo/gpt4-x-alpaca")