''' infer.py for runpod worker '''

import argparse
import os

import predict

import runpod
from runpod.serverless.utils import rp_download, rp_upload, rp_cleanup
from runpod.serverless.utils.rp_validator import validate

# Directory that holds a locally cached model when --model_url is not a
# Hugging Face URL.
# NOTE(review): the original file referenced MODEL_CACHE_DIR without defining
# it; the default below is an assumption - confirm it matches the directory the
# image build step downloads the model into.
MODEL_CACHE_DIR = os.environ.get("MODEL_CACHE_DIR", "diffusers-cache")

# Predictor instance shared by every job; assigned in the __main__ block
# before the serverless worker starts.
MODEL = None

# Values accepted for the 'width' and 'height' inputs.
ALLOWED_DIMENSIONS = (128, 256, 384, 448, 512, 576, 640, 704, 768)

# Values accepted for the 'scheduler' input.
ALLOWED_SCHEDULERS = (
    'DDIM', 'DDPM', 'DPM-M', 'DPM-S', 'EULER-A', 'EULER-D',
    'HEUN', 'IPNDM', 'KDPM2-A', 'KDPM2-D', 'PNDM', 'K-LMS',
)

INPUT_SCHEMA = {
    'prompt': {
        'type': str,
        'required': True
    },
    'negative_prompt': {
        'type': str,
        'required': False,
        'default': None
    },
    'width': {
        'type': int,
        'required': False,
        'default': 512,
        'constraints': lambda width: width in ALLOWED_DIMENSIONS
    },
    'height': {
        'type': int,
        'required': False,
        'default': 512,
        'constraints': lambda height: height in ALLOWED_DIMENSIONS
    },
    'init_image': {
        'type': str,
        'required': False,
        'default': None
    },
    'mask': {
        'type': str,
        'required': False,
        'default': None
    },
    'prompt_strength': {
        'type': float,
        'required': False,
        'default': 0.8,
        'constraints': lambda prompt_strength: 0 <= prompt_strength <= 1
    },
    'num_outputs': {
        'type': int,
        'required': False,
        'default': 1,
        'constraints': lambda num_outputs: 10 > num_outputs > 0
    },
    'num_inference_steps': {
        'type': int,
        'required': False,
        'default': 50,
        'constraints': lambda num_inference_steps: 0 < num_inference_steps < 500
    },
    'guidance_scale': {
        'type': float,
        'required': False,
        'default': 7.5,
        'constraints': lambda guidance_scale: 0 < guidance_scale < 20
    },
    'scheduler': {
        'type': str,
        'required': False,
        'default': 'K-LMS',
        'constraints': lambda scheduler: scheduler in ALLOWED_SCHEDULERS
    },
    'seed': {
        'type': int,
        'required': False,
        'default': None
    },
    'nsfw': {
        'type': bool,
        'required': False,
        'default': False
    }
}


def run(job):
    '''
    Run inference on the model.

    `job` is the job dict handed over by the serverless worker; its 'input'
    entry is validated against INPUT_SCHEMA and its 'id' entry is used when
    uploading the generated images.

    Returns {"error": [...]} when validation fails. Otherwise returns a list
    with one dict per generated image, each holding the value returned by the
    image upload under "image" and the seed for that image under "seed"
    (the base seed plus the image index).
    '''
    # Input validation
    validated = validate(job['input'], INPUT_SCHEMA)

    if 'errors' in validated:
        return {"error": validated['errors']}

    # Work on the validated input from here on so that the schema defaults
    # are applied; reading the raw input raised KeyError for omitted
    # optional fields such as 'seed' or 'guidance_scale'.
    job_input = validated['validated_input']

    # Download input objects
    job_input['init_image'], job_input['mask'] = rp_download.download_input_objects(
        [job_input.get('init_image', None), job_input.get('mask', None)]
    )  # pylint: disable=unbalanced-tuple-unpacking

    # Fallback matches the schema default (False).
    MODEL.NSFW = job_input.get('nsfw', False)

    # No seed supplied: draw a random 16-bit seed.
    if job_input.get('seed') is None:
        job_input['seed'] = int.from_bytes(os.urandom(2), "big")

    try:
        img_paths = MODEL.predict(
            prompt=job_input["prompt"],
            negative_prompt=job_input.get("negative_prompt", None),
            width=job_input.get('width', 512),
            height=job_input.get('height', 512),
            init_image=job_input['init_image'],
            mask=job_input['mask'],
            prompt_strength=job_input.get('prompt_strength', 0.8),
            num_outputs=job_input.get('num_outputs', 1),
            num_inference_steps=job_input.get('num_inference_steps', 50),
            guidance_scale=job_input.get('guidance_scale', 7.5),
            scheduler=job_input.get('scheduler', "K-LMS"),
            seed=job_input['seed']
        )

        job_output = []
        for index, img_path in enumerate(img_paths):
            image_url = rp_upload.upload_image(job['id'], img_path, index)

            job_output.append({
                "image": image_url,
                "seed": job_input['seed'] + index
            })
    finally:
        # Remove downloaded input objects, even when prediction or upload fails
        rp_cleanup.clean(['input_objects'])

    return job_output


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_url", type=str,
                    default=None, help="Model URL")


if __name__ == "__main__":
    args = parser.parse_args()
    print(args)

    # --model_url defaults to None, so check it is set before the substring
    # test; otherwise fall back to the locally cached model.
    if args.model_url and "huggingface.co" in args.model_url:
        # The last two path segments form the model id; a trailing slash is
        # stripped so it does not produce an empty last segment.
        url_parts = args.model_url.rstrip("/").split("/")
        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
    else:
        model_id = f"./{MODEL_CACHE_DIR}/model"

    MODEL = predict.Predictor(model_id)
    MODEL.setup()

    runpod.serverless.start({"handler": run})