Skip to content
This repository was archived by the owner on Nov 3, 2020. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dask_ec2/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def cli(ctx):
show_default=True,
required=False,
help="EC2 Instance Type")
@click.option("--count", default=4, show_default=True, required=False, help="Number of nodes")
@click.option("--count", default=4, show_default=True, required=False, help="Number of nodes (at least one recommended for scheduler)")
@click.option("--spot-count", default=0, show_default=True, required=False, help="Number of spot instance nodes")
@click.option("--spot-price", default="0.10", show_default=True, required=False, help="Maximum spot price for the spot instances")
@click.option("--security-group",
"security_group_name",
default="dask-ec2-default",
Expand Down Expand Up @@ -138,7 +140,7 @@ def cli(ctx):
show_default=True,
help="Install Dask/Distributed from git master")
def up(ctx, name, keyname, keypair, region_name, vpc_id, subnet_id,
iaminstance_name, ami, username, instance_type, count,
iaminstance_name, ami, username, instance_type, count, spot_count, spot_price,
security_group_name, security_group_id, volume_type, volume_size, filepath, _provision, anaconda_,
dask, notebook, nprocs, source):
import os
Expand All @@ -158,6 +160,8 @@ def up(ctx, name, keyname, keypair, region_name, vpc_id, subnet_id,
image_id=ami,
instance_type=instance_type,
count=count,
spot_count = spot_count,
spot_price = spot_price,
keyname=keyname,
security_group_name=security_group_name,
security_group_id=security_group_id,
Expand Down
55 changes: 49 additions & 6 deletions dask_ec2/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import time
import logging
import copy

import boto3
from botocore.exceptions import ClientError, WaiterError
Expand Down Expand Up @@ -207,8 +208,35 @@ def check_image_is_ebs(self, image_id):
if root_type != "ebs":
raise DaskEc2Exception("The AMI {} Root Device Type is not EBS. Only EBS Root Device AMI are supported.".format(
image_id))

def wait_for_fulfillment(self, request_ids, pending_request_ids):
    """Block until every pending spot request has been fulfilled.

    Polls ``self.client.describe_spot_instance_requests`` for the requests
    listed in ``pending_request_ids`` and sleeps 10 seconds between polls
    while any of them is still unfulfilled.

    Parameters
    ----------
    request_ids : list of str
        All spot request ids whose instance ids should be returned.
    pending_request_ids : list of str
        Spot request ids still awaiting fulfillment. This list is mutated
        in place: each id is removed once its request is fulfilled, so
        callers that need the original list must pass a copy.

    Returns
    -------
    list of str
        The ``InstanceId`` of every request in ``request_ids`` (empty when
        ``request_ids`` is empty).

    Raises
    ------
    DaskEc2Exception
        If a pending request reaches a state from which it can no longer
        be fulfilled.
    """
    # Request states that are final; a request in one of these states that
    # is not reported as fulfilled would otherwise be polled forever.
    terminal_states = ('cancelled', 'failed', 'closed')

    # Iterate instead of recursing: one stack frame per 10-second poll would
    # eventually exceed the interpreter's recursion limit on long waits.
    while pending_request_ids:
        results = self.client.describe_spot_instance_requests(
            SpotInstanceRequestIds=pending_request_ids)
        for result in results['SpotInstanceRequests']:
            rid = result['SpotInstanceRequestId']
            status_code = result['Status']['Code']
            if status_code == 'fulfilled':
                # Guard the removal so an unexpected or duplicated id in the
                # response cannot raise ValueError.
                if rid in pending_request_ids:
                    pending_request_ids.remove(rid)
                logger.debug("spot request `{}` fulfilled!".format(rid))
            elif result.get('State') in terminal_states:
                raise DaskEc2Exception(
                    "Spot request `{}` ended in state `{}` (status `{}`) and will not be fulfilled.".format(
                        rid, result.get('State'), status_code))
            else:
                logger.debug("waiting on `{}`".format(rid))

        if pending_request_ids:
            time.sleep(10)

    logger.debug("all spots fulfilled!")
    if not request_ids:
        # An empty id filter would make the describe call return every spot
        # request visible to the account, so short-circuit instead.
        return []
    # Return the instance ids backing the fulfilled requests.
    response = self.client.describe_spot_instance_requests(
        SpotInstanceRequestIds=request_ids)['SpotInstanceRequests']
    return [request['InstanceId'] for request in response]


def launch(self, name, image_id, instance_type, count, keyname,
spot_count = 0,
spot_price = "0.10",
security_group_name=DEFAULT_SG_GROUP_NAME,
security_group_id=None,
volume_type="gp2",
Expand Down Expand Up @@ -238,20 +266,32 @@ def launch(self, name, image_id, instance_type, count, keyname,
else:
security_groups_ids = self.get_security_groups_ids(security_group_name)

logger.debug("Creating %i instances on EC2", count)
kwargs = dict(ImageId=image_id,
KeyName=keyname,
MinCount=count,
MaxCount=count,
InstanceType=instance_type,
SecurityGroupIds=self.get_security_groups_ids(security_group_name),
BlockDeviceMappings=device_map)
if self.subnet_id is not None and self.subnet_id != "":
kwargs['SubnetId'] = self.subnet_id
if self.iaminstance_name is not None and self.iaminstance_name != "":
kwargs['IamInstanceProfile'] = {'Name': self.iaminstance_name}
if spot_count:
logger.debug("Creating %i spot instances on EC2", spot_count)
response = self.client.request_spot_instances(SpotPrice=spot_price,
InstanceCount = spot_count,
Type='one-time',
LaunchSpecification=kwargs)
request_ids = [r['SpotInstanceRequestId'] for r in response['SpotInstanceRequests']]
else:
request_ids = None
logger.debug("Creating %i instances on EC2", count)
kwargs ['MinCount']=count
kwargs['MaxCount']=count
if self.subnet_id is not None and self.subnet_id != "":
kwargs['SubnetId'] = self.subnet_id
if self.iaminstance_name is not None and self.iaminstance_name != "":
kwargs['IamInstanceProfile'] = {'Name': self.iaminstance_name}
instances = self.ec2.create_instances(**kwargs)

time.sleep(5)

ids = [i.id for i in instances]
Expand All @@ -261,8 +301,11 @@ def launch(self, name, image_id, instance_type, count, keyname,
except WaiterError:
raise DaskEc2Exception(
"An unexpected error occurred when launching the requested instances. Refer to the AWS Management Console for more information.")

collection = self.ec2.instances.filter(InstanceIds=ids)
if request_ids:
spot_instances_ids = self.wait_for_fulfillment(request_ids, copy.deepcopy(request_ids))
else:
spot_instances_ids = []
collection = self.ec2.instances.filter(InstanceIds=ids + spot_instances_ids)
instances = []
for i, instance in enumerate(collection):
instances.append(instance)
Expand Down