You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
157 lines
4.3 KiB
157 lines
4.3 KiB
''' infer.py for runpod worker '''
|
|
|
|
import argparse
import os

import runpod
from runpod.serverless.utils import rp_download, rp_upload, rp_cleanup
from runpod.serverless.utils.rp_validator import validate

import predict
|
|
|
|
|
|
# Side lengths accepted by the 'width' and 'height' constraints.
_ALLOWED_SIDE_LENGTHS = (128, 256, 384, 448, 512, 576, 640, 704, 768)

# Names accepted by the 'scheduler' constraint.
_ALLOWED_SCHEDULERS = (
    'DDIM', 'DDPM', 'DPM-M', 'DPM-S', 'EULER-A', 'EULER-D',
    'HEUN', 'IPNDM', 'KDPM2-A', 'KDPM2-D', 'PNDM', 'K-LMS',
)


def _optional_field(field_type, default, constraints=None):
    ''' Build the schema entry for a non-required input field. '''
    entry = {'type': field_type, 'required': False, 'default': default}
    if constraints is not None:
        entry['constraints'] = constraints
    return entry


# Validation schema for the job input: one entry per accepted key, giving its
# type, whether it is required, its default and an optional constraint check.
INPUT_SCHEMA = {
    'prompt': {'type': str, 'required': True},
    'negative_prompt': _optional_field(str, None),
    'width': _optional_field(
        int, 512, lambda value: value in _ALLOWED_SIDE_LENGTHS),
    'height': _optional_field(
        int, 512, lambda value: value in _ALLOWED_SIDE_LENGTHS),
    'init_image': _optional_field(str, None),
    'mask': _optional_field(str, None),
    'prompt_strength': _optional_field(
        float, 0.8, lambda value: 0 <= value <= 1),
    'num_outputs': _optional_field(
        int, 1, lambda value: 10 > value > 0),
    'num_inference_steps': _optional_field(
        int, 50, lambda value: 0 < value < 500),
    'guidance_scale': _optional_field(
        float, 7.5, lambda value: 0 < value < 20),
    'scheduler': _optional_field(
        str, 'K-LMS', lambda value: value in _ALLOWED_SCHEDULERS),
    'seed': _optional_field(int, None),
    'nsfw': _optional_field(bool, False),
}
|
|
|
|
|
|
def run(job):
    '''
    Run inference on the model.

    The job's 'input' is validated against INPUT_SCHEMA, the optional init
    image and mask are fetched with rp_download, MODEL.predict is called and
    every resulting image path is handed to rp_upload.upload_image.

    Returns a list with one {"image": ..., "seed": ...} dict per generated
    image, where "seed" is the base seed plus the image's index. When
    validation reports errors, returns {"error": <errors>} instead.
    '''
    # Input validation
    validation = validate(job['input'], INPUT_SCHEMA)

    if 'errors' in validation:
        return {"error": validation['errors']}
    validated_input = validation['validated_input']

    # Work from the validated values only, and fall back to the default
    # declared in INPUT_SCHEMA for any key the validator did not fill in, so
    # every schema key is guaranteed to be present below.
    job_input = {
        key: validated_input.get(key, rules.get('default'))
        for key, rules in INPUT_SCHEMA.items()
    }

    try:
        # Download input objects
        init_image, mask = rp_download.download_input_objects(
            [job_input['init_image'], job_input['mask']]
        )  # pylint: disable=unbalanced-tuple-unpacking

        MODEL.NSFW = job_input['nsfw']

        seed = job_input['seed']
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")

        img_paths = MODEL.predict(
            prompt=job_input['prompt'],
            negative_prompt=job_input['negative_prompt'],
            width=job_input['width'],
            height=job_input['height'],
            init_image=init_image,
            mask=mask,
            prompt_strength=job_input['prompt_strength'],
            num_outputs=job_input['num_outputs'],
            num_inference_steps=job_input['num_inference_steps'],
            guidance_scale=job_input['guidance_scale'],
            scheduler=job_input['scheduler'],
            seed=seed
        )

        job_output = []
        for index, img_path in enumerate(img_paths):
            image_url = rp_upload.upload_image(job['id'], img_path, index)

            job_output.append({
                "image": image_url,
                "seed": seed + index
            })
    finally:
        # Remove downloaded input objects, even when prediction or upload
        # failed, so they do not pile up across jobs.
        rp_cleanup.clean(['input_objects'])

    return job_output
|
|
|
|
|
|
# Command-line interface: the only option is the URL of the model to serve.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_url", type=str,
                    default=None, help="Model URL")

if __name__ == "__main__":
    args = parser.parse_args()
    print(args)

    # --model_url is optional, so it may be None here. A Hugging Face URL is
    # reduced to its last two path segments; anything else (including no URL)
    # falls back to the fixed model file name.
    if args.model_url and "huggingface.co" in args.model_url:
        # Drop trailing slashes so the final segment is never empty.
        url_parts = args.model_url.rstrip("/").split("/")
        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
    else:
        model_id = "model.safetensors"

    # MODEL is a module-level global read by run().
    MODEL = predict.Predictor(model_id)
    MODEL.setup()

    runpod.serverless.start({"handler": run})
|
|
|