diff --git a/dask_ec2/cli/main.py b/dask_ec2/cli/main.py index e796863..f0c528f 100644 --- a/dask_ec2/cli/main.py +++ b/dask_ec2/cli/main.py @@ -74,7 +74,9 @@ def cli(ctx): show_default=True, required=False, help="EC2 Instance Type") -@click.option("--count", default=4, show_default=True, required=False, help="Number of nodes") +@click.option("--count", default=4, show_default=True, required=False, help="Number of nodes (at least one recommended for scheduler)") +@click.option("--spot-count", default=0, show_default=True, required=False, help="Number of spot instance nodes") +@click.option("--spot-price", default="0.10", show_default=True, required=False, help="Maximum spot price for the spot instances") @click.option("--security-group", "security_group_name", default="dask-ec2-default", @@ -138,7 +140,7 @@ def cli(ctx): show_default=True, help="Install Dask/Distributed from git master") def up(ctx, name, keyname, keypair, region_name, vpc_id, subnet_id, - iaminstance_name, ami, username, instance_type, count, + iaminstance_name, ami, username, instance_type, count, spot_count, spot_price, security_group_name, security_group_id, volume_type, volume_size, filepath, _provision, anaconda_, dask, notebook, nprocs, source): import os @@ -158,6 +160,8 @@ def up(ctx, name, keyname, keypair, region_name, vpc_id, subnet_id, image_id=ami, instance_type=instance_type, count=count, + spot_count = spot_count, + spot_price = spot_price, keyname=keyname, security_group_name=security_group_name, security_group_id=security_group_id, diff --git a/dask_ec2/ec2.py b/dask_ec2/ec2.py index 397a285..e4c074b 100644 --- a/dask_ec2/ec2.py +++ b/dask_ec2/ec2.py @@ -2,6 +2,7 @@ import time import logging +import copy import boto3 from botocore.exceptions import ClientError, WaiterError @@ -207,8 +208,35 @@ def check_image_is_ebs(self, image_id): if root_type != "ebs": raise DaskEc2Exception("The AMI {} Root Device Type is not EBS. Only EBS Root Device AMI are supported.".format( image_id)) + + def wait_for_fulfillment(self, request_ids, pending_request_ids): + """Loop through all pending request ids waiting for them to be fulfilled. + If a request is fulfilled, remove it from pending_request_ids. + If there are still pending requests, sleep and check again in 10 seconds. + Only return when all spot requests have been fulfilled.""" + results = self.client.describe_spot_instance_requests(SpotInstanceRequestIds=pending_request_ids) + for result in results['SpotInstanceRequests']: + rid = result['SpotInstanceRequestId'] + if result['Status']['Code'] == 'fulfilled': + pending_request_ids.pop(pending_request_ids.index(rid)) + logger.debug("spot request `{}` fulfilled!".format(rid)) + else: + logger.debug("waiting on `{}`".format(rid)) + + if len(pending_request_ids) == 0: + logger.debug("all spots fulfilled!") + #return the instance ids + response = self.client.describe_spot_instance_requests(SpotInstanceRequestIds=request_ids)['SpotInstanceRequests'] + instances = [i['InstanceId'] for i in response] + return instances + else: + time.sleep(10) + return self.wait_for_fulfillment(request_ids, pending_request_ids) + def launch(self, name, image_id, instance_type, count, keyname, + spot_count = 0, + spot_price = "0.10", security_group_name=DEFAULT_SG_GROUP_NAME, security_group_id=None, volume_type="gp2", @@ -238,20 +266,32 @@ def launch(self, name, image_id, instance_type, count, keyname, else: security_groups_ids = self.get_security_groups_ids(security_group_name) - logger.debug("Creating %i instances on EC2", count) kwargs = dict(ImageId=image_id, KeyName=keyname, - MinCount=count, - MaxCount=count, InstanceType=instance_type, SecurityGroupIds=self.get_security_groups_ids(security_group_name), BlockDeviceMappings=device_map) + if self.subnet_id is not None and self.subnet_id != "": + kwargs['SubnetId'] = self.subnet_id + if self.iaminstance_name is not None and self.iaminstance_name != "": + kwargs['IamInstanceProfile'] = {'Name': self.iaminstance_name} + if spot_count: + logger.debug("Creating %i spot instances on EC2", spot_count) + response = self.client.request_spot_instances(SpotPrice=spot_price, + InstanceCount = spot_count, + Type='one-time', + LaunchSpecification=kwargs) + request_ids = [r['SpotInstanceRequestId'] for r in response['SpotInstanceRequests']] + else: + request_ids = None + logger.debug("Creating %i instances on EC2", count) + kwargs ['MinCount']=count + kwargs['MaxCount']=count if self.subnet_id is not None and self.subnet_id != "": kwargs['SubnetId'] = self.subnet_id if self.iaminstance_name is not None and self.iaminstance_name != "": kwargs['IamInstanceProfile'] = {'Name': self.iaminstance_name} instances = self.ec2.create_instances(**kwargs) - time.sleep(5) ids = [i.id for i in instances] @@ -261,8 +301,11 @@ def launch(self, name, image_id, instance_type, count, keyname, except WaiterError: raise DaskEc2Exception( "An unexpected error occurred when launching the requested instances. Refer to the AWS Management Console for more information.") - - collection = self.ec2.instances.filter(InstanceIds=ids) + if request_ids: + spot_instances_ids = self.wait_for_fulfillment(request_ids, copy.deepcopy(request_ids)) + else: + spot_instances_ids = [] + collection = self.ec2.instances.filter(InstanceIds=ids + spot_instances_ids) instances = [] for i, instance in enumerate(collection): instances.append(instance)