Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

158 lines
4.3 KiB

''' infer.py for runpod worker '''
import argparse
import os

import runpod
from runpod.serverless.utils import rp_download, rp_upload, rp_cleanup
from runpod.serverless.utils.rp_validator import validate

import predict

# Accepted values for the 'width' and 'height' inputs.
_DIMENSION_CHOICES = (128, 256, 384, 448, 512, 576, 640, 704, 768)

# Scheduler names accepted by the 'scheduler' input.
_SCHEDULER_CHOICES = (
    'DDIM', 'DDPM', 'DPM-M', 'DPM-S', 'EULER-A', 'EULER-D',
    'HEUN', 'IPNDM', 'KDPM2-A', 'KDPM2-D', 'PNDM', 'K-LMS',
)

# Schema handed to runpod's ``validate`` in ``run``. Each entry gives the expected
# ``type``, whether the key is ``required``, an optional ``default`` and an optional
# ``constraints`` predicate on the value (interpreted by runpod's validator).
INPUT_SCHEMA = dict(
    prompt={'type': str, 'required': True},
    negative_prompt={'type': str, 'required': False, 'default': None},
    width={
        'type': int, 'required': False, 'default': 512,
        'constraints': lambda value: value in _DIMENSION_CHOICES,
    },
    height={
        'type': int, 'required': False, 'default': 512,
        'constraints': lambda value: value in _DIMENSION_CHOICES,
    },
    init_image={'type': str, 'required': False, 'default': None},
    mask={'type': str, 'required': False, 'default': None},
    prompt_strength={
        'type': float, 'required': False, 'default': 0.8,
        'constraints': lambda value: 0 <= value <= 1,
    },
    num_outputs={
        'type': int, 'required': False, 'default': 1,
        # Between 1 and 9 outputs inclusive.
        'constraints': lambda value: 0 < value < 10,
    },
    num_inference_steps={
        'type': int, 'required': False, 'default': 50,
        'constraints': lambda value: 0 < value < 500,
    },
    guidance_scale={
        'type': float, 'required': False, 'default': 7.5,
        'constraints': lambda value: 0 < value < 20,
    },
    scheduler={
        'type': str, 'required': False, 'default': 'K-LMS',
        'constraints': lambda value: value in _SCHEDULER_CHOICES,
    },
    seed={'type': int, 'required': False, 'default': None},
    nsfw={'type': bool, 'required': False, 'default': False},
)
def run(job):
    '''
    Run inference on the model for a single RunPod job.

    Validates ``job['input']`` against INPUT_SCHEMA, downloads any input objects,
    runs ``MODEL.predict`` and uploads each generated image.

    Returns a list with one ``{"image": <uploaded url>, "seed": <int>}`` dict per
    generated image, or ``{"error": <validation errors>}`` when validation fails.
    '''
    job_input = job['input']

    # Input validation. Everything below reads from the validated copy, which has
    # the schema defaults filled in; the raw input may be missing optional keys
    # such as 'seed', 'prompt_strength' or 'guidance_scale'.
    validation = validate(job_input, INPUT_SCHEMA)
    if 'errors' in validation:
        return {"error": validation['errors']}
    validated_input = validation['validated_input']

    try:
        # Download input objects (inside the try so partial downloads are cleaned up too).
        validated_input['init_image'], validated_input['mask'] = rp_download.download_input_objects(
            [validated_input.get('init_image', None), validated_input.get('mask', None)]
        )  # pylint: disable=unbalanced-tuple-unpacking

        # Use the schema default (False) rather than a divergent hard-coded True.
        MODEL.NSFW = validated_input['nsfw']

        seed = validated_input['seed']
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")

        img_paths = MODEL.predict(
            prompt=validated_input["prompt"],
            negative_prompt=validated_input["negative_prompt"],
            width=validated_input['width'],
            height=validated_input['height'],
            init_image=validated_input['init_image'],
            mask=validated_input['mask'],
            prompt_strength=validated_input['prompt_strength'],
            num_outputs=validated_input['num_outputs'],
            num_inference_steps=validated_input['num_inference_steps'],
            guidance_scale=validated_input['guidance_scale'],
            scheduler=validated_input['scheduler'],
            seed=seed
        )

        # Reports seed + index for the index-th image; presumably this mirrors how
        # the predictor seeds each output -- NOTE(review): confirm in predict.py.
        job_output = [
            {
                "image": rp_upload.upload_image(job['id'], img_path, index),
                "seed": seed + index
            }
            for index, img_path in enumerate(img_paths)
        ]
    finally:
        # Remove downloaded input objects even when prediction or upload fails.
        rp_cleanup.clean(['input_objects'])

    return job_output
def _build_parser():
    '''Create the command-line parser for the worker entry point.'''
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--model_url", default=None, type=str, help="Model URL")
    return cli


parser = _build_parser()
if __name__ == "__main__":
args = parser.parse_args()
print(args)
if "huggingface.co" in args.model_url:
url_parts = args.model_url.split("/")
model_id = f"{url_parts[-2]}/{url_parts[-1]}"
else:
model_id = f"model.safetensors"
MODEL = predict.Predictor(model_id)
MODEL.setup()
runpod.serverless.start({"handler": run})