Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions project/config/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ azure-keyvault
azure-mgmt-common
azure-mgmt-resource
azure-mgmt-compute
azure-mgmt-monitor
black
six
pytz>=2018.7
Expand Down
24 changes: 17 additions & 7 deletions project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ def validate_keys_for_user(userdata, config_map, username, keys_to_delete):
return
user_data = userdata.get(username_to_validate)
if user_data.get("plugins"):
if user_data.get("plugins")[0].get("iam"):
if "get_new_key" in user_data.get("plugins")[0].get("iam")[0]:
iam_plugin = user_data.get("plugins")[0].get("iam")
if iam_plugin:
if (
"get_new_key" in iam_plugin[0]
or "rotate_ses_smtp_user" in iam_plugin[0]
):
validation_result = validate_new_key(
config_map, username_to_validate, user_data
)
Expand All @@ -41,7 +45,7 @@ def validate_keys_for_user(userdata, config_map, username, keys_to_delete):
keys_to_delete.append((username_to_validate, old_key, prompt))
else:
logging.info(
f" No get_new_key section for iam plugin for user {username_to_validate} - skipping"
f" No get_new_key or rotate_ses_smtp_user section for iam plugin for user {username_to_validate} - skipping"
)
else:
logging.info(
Expand Down Expand Up @@ -171,12 +175,18 @@ def verify_public_ip(required_public_ip):
if required_public_ip.lower() == "false":
logging.info("Skipping public IP verification.")
else:
logging.info(f"Checking if current public IP is {required_public_ip} (either in the office or on VPN)")
myip = requests.get('https://api.ipify.org').text
logging.info(
f"Checking if current public IP is {required_public_ip} (either in the office or on VPN)"
)
myip = requests.get("https://api.ipify.org").text
if myip == required_public_ip:
logging.info(f'Verified public IP address is: {myip} - LOCK will continue')
logging.info(
f"Verified public IP address is: {myip} - LOCK will continue"
)
else:
logging.error(f'Incorrect public IP detected ({myip}) - LOCK cannot continue')
logging.error(
f"Incorrect public IP detected ({myip}) - LOCK cannot continue"
)
sys.exit(1)


Expand Down
4 changes: 2 additions & 2 deletions project/plugins/1password.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def create_item(

def get_item_id(client: Client, vault_id: str, item_title: str) -> str:
item_id = None
items = asyncio.run(client.items.list(vault_id))
items = asyncio.run(client.items.list_all(vault_id)).obj
for item in items:
if item.title == item_title:
item_id = item.id
Expand All @@ -107,7 +107,7 @@ def get_item_id(client: Client, vault_id: str, item_title: str) -> str:

def get_vault_id(client: Client, vault_title: str) -> str:
vault_id = None
vaults = asyncio.run(client.vaults.list())
vaults = asyncio.run(client.vaults.list_all()).obj
for vault in vaults:
if vault.title == vault_title:
vault_id = vault.id
Expand Down
247 changes: 180 additions & 67 deletions project/plugins/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from azure.identity import ClientSecretCredential
from azure.keyvault.secrets import SecretClient
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.monitor import MonitorManagementClient
from azure.mgmt.monitor.models import AutoscaleSettingResourcePatch
from copy import deepcopy
from msrest.authentication import BasicAuthentication
from project.plugins.ssh import ssh_server_command
from project import values
from time import sleep

import logging
import msrestazure.azure_exceptions

logging.getLogger("azure.keyvault.secrets").setLevel(logging.CRITICAL)
logging.getLogger("azure.mgmt.resource.resources").setLevel(logging.CRITICAL)
Expand All @@ -20,77 +21,189 @@
logging.getLogger("urllib3").setLevel(logging.CRITICAL)


def rotate_vms(config_map, username, **key_args):
def revert_changes(
    username, autoscale_settings_operations, resource_group, autoscale_setting_resource
):
    """Re-apply a previously saved autoscale configuration for 'autoscalehost'.

    Used to undo the capacity doubling after a rotation completes. Under dry
    run this only logs and returns without touching Azure.

    Args:
        username: User whose rotation is in progress (logging only).
        autoscale_settings_operations: Azure autoscale settings operations client.
        resource_group: Resource group containing the autoscale setting.
        autoscale_setting_resource: Snapshot of the settings to restore.
    """
    if values.DryRun:
        logging.info(
            f"User {username}: Dry run enabled. No changes to be reverted in scale set {autoscale_setting_resource.name}"
        )
        return

    # Push the saved snapshot back to the fixed-name autoscale setting.
    restore_kwargs = {
        "resource_group_name": resource_group,
        "autoscale_setting_name": "autoscalehost",
        "autoscale_setting_resource": autoscale_setting_resource,
    }
    autoscale_settings_operations.update(**restore_kwargs)


def poll_until_success(username, client, resource_group, min_instances, interval=30):
    """Block until `min_instances` VMs exist in `resource_group` and each one
    reaches provisioning state 'Succeeded'.

    No-op under dry run. Otherwise polls Azure every `interval` seconds and
    only returns once both phases complete.

    Args:
        username: User whose rotation is in progress (logging only).
        client: Azure ComputeManagementClient used to list/get VMs.
        resource_group: Resource group whose VMs are polled.
        min_instances: Target instance count to wait for.
            NOTE(review): compared with `!=` below — assumes the count never
            overshoots the target and that this is an int, not the string
            Azure returns for capacity values — TODO confirm at the call site.
        interval: Seconds to sleep between polls.
    """
    if values.DryRun:
        logging.info(f"User {username}: Dry run enabled. Polling will be skipped")
        return

    vm_names = set()

    logging.info(
        f"User {username}: Waiting for number of instances to reach minimum value"
    )
    # Phase 1: wait for the VM count to reach min_instances.
    # NOTE(review): vm_names only ever accumulates — VMs deleted mid-poll are
    # never removed — and the loop exits only on exact equality, so a brief
    # overshoot past min_instances would spin forever; consider >= — confirm.
    while len(vm_names) != min_instances:
        virtual_machines = client.virtual_machines.list(
            resource_group_name=resource_group
        )

        for vm in virtual_machines:
            vm_names.add(vm.name)

        logging.info(f"User {username}: There are currently {len(vm_names)} instances")

        if len(vm_names) != min_instances:
            logging.info(f"User {username}: Sleeping for {interval} seconds")
            sleep(interval)

    logging.info(f"User {username}: Minimum number of instances reached!")

    logging.info(
        f"User {username}: Waiting for instances to enter provisioning state 'Succeeded'"
    )
    # Phase 2: wait for every discovered VM to finish provisioning.
    for vm_name in vm_names:
        state = ""
        while state != "Succeeded":
            state = client.virtual_machines.get(
                resource_group_name=resource_group, vm_name=vm_name
            ).provisioning_state
            logging.debug(f"{vm_name} State: {state}")

            if state != "Succeeded":
                logging.info(f"User {username}: Sleeping for {interval} seconds")
                sleep(interval)
        logging.info(
            f"User {username}: {vm_name} is now in provisioning state 'Succeeded'"
        )


def scale_up_instances(username, autoscale_settings_operations, resource_group):
    """Double the minimum (and default) capacity of the group's fixed-name
    'autoscalehost' autoscale setting so the scale set cycles in fresh VMs.

    Args:
        username: User whose rotation is in progress (logging only).
        autoscale_settings_operations: Azure autoscale settings operations client.
        resource_group: Resource group containing the autoscale setting.

    Returns:
        tuple: ``(old_autoscale_setting_resource, new_minimum)`` — a deep copy
        of the pre-change settings (for ``revert_changes``) and the minimum
        capacity after doubling. Returns ``(None, None)`` when doubling the
        minimum would exceed the configured maximum (no update is sent).
        Under dry run, returns the snapshot and the *unchanged* minimum
        (which at that point is still the string Azure returned) without
        calling update.
    """
    logging.info(f"User {username}: Doubling minimum instance capacity")
    autoscale_setting = autoscale_settings_operations.get(
        resource_group_name=resource_group, autoscale_setting_name="autoscalehost"
    )

    profile = autoscale_setting.profiles[0]
    capacity = profile.capacity

    # The patch shares `profiles` (and thus `capacity` and the rules) with
    # `autoscale_setting`, so the mutations below flow into this patch object.
    autoscale_setting_resource = AutoscaleSettingResourcePatch(
        tags=autoscale_setting.tags,
        profiles=autoscale_setting.profiles,
        notifications=autoscale_setting.notifications,
        enabled=autoscale_setting.enabled,
        name=autoscale_setting.name,
        target_resource_uri=autoscale_setting.target_resource_uri,
        target_resource_location=autoscale_setting.target_resource_location,
    )
    # Deep copy taken BEFORE any mutation: this snapshot is what
    # revert_changes later restores. Order matters here.
    old_autoscale_setting_resource = deepcopy(autoscale_setting_resource)

    minimum = int(capacity.minimum)
    maximum = int(capacity.maximum)
    # Headroom left between the doubled minimum and the configured maximum.
    max_increase = maximum - minimum * 2

    if max_increase < 0:
        max_increase = 0

    # Clamp every scale-out rule so a burst cannot push past the maximum
    # once the minimum has been doubled.
    for rule in profile.rules:
        if rule.scale_action.direction == "Increase":
            increase_value = int(rule.scale_action.value)
            rule.scale_action.value = (
                increase_value if increase_value < max_increase else max_increase
            )

    if values.DryRun:
        logging.info(
            f"User {username}: Dry run enabled. Instances in scale set {autoscale_setting.name} will not be cycled"
        )
        return old_autoscale_setting_resource, capacity.minimum

    if minimum * 2 <= maximum:
        # Apply the doubled floor; default follows minimum so the autoscaler
        # actually targets the new count.
        capacity.minimum = minimum * 2
        capacity.default = capacity.minimum

        autoscale_settings_operations.update(
            resource_group_name=resource_group,
            autoscale_setting_name="autoscalehost",
            autoscale_setting_resource=autoscale_setting_resource,
        )

        return old_autoscale_setting_resource, capacity.minimum
    else:
        # Doubling would exceed the maximum — signal the caller to skip.
        return None, None


def get_scale_sets_by_prefix(scale_set_prefix, resource_group_prefixes, scale_sets):
    """Select scale sets by resource-group prefix and scale-set name match.

    Args:
        scale_set_prefix: Substring matched against each scale set's name.
            NOTE(review): despite the parameter name this is a containment
            test (``in``), not a prefix test — preserved as-is.
        resource_group_prefixes: Iterable of lowercase resource-group name
            prefixes; a scale set is considered only if its resource group
            starts with one of them.
        scale_sets: Iterable of Azure scale-set objects exposing ``id`` and
            ``name``.

    Returns:
        list[dict]: One ``{"Name": ..., "ResourceGroup": ...}`` entry per
        match, with the resource-group name lower-cased.
    """
    matched_scale_sets = []
    # str.startswith accepts a tuple of prefixes — one C-level call replaces
    # the per-prefix any([...]) list that was built on every iteration.
    prefixes = tuple(resource_group_prefixes)

    for scale_set in scale_sets:
        # ARM resource IDs look like
        # /subscriptions/<sub>/resourceGroups/<rg>/providers/...,
        # so segment 4 of the '/'-split is the resource-group name.
        resource_group = scale_set.id.split("/")[4].lower()
        if not resource_group.startswith(prefixes):
            continue

        if scale_set_prefix in scale_set.name:
            matched_scale_sets.append(
                {"Name": scale_set.name, "ResourceGroup": resource_group}
            )

    return matched_scale_sets


def rotate_vms(config_map, username, **key_args) -> None:
auth = config_map["Global"]["azure_credentials"][key_args.get("account")]
credentials = ClientSecretCredential(
auth.get("tenant"), auth.get("client_id"), auth.get("secret")
)
subscriptions = key_args.get("resource_group_subscriptionid")
scale_set_prefix = key_args.get("scale_set_prefix")
subscriptions = key_args.get("subscriptions")

for subscription in subscriptions:
region = None
to_rotate = []
for region in subscription:
for resource_group_name_prefix, subscription_id in subscription.get(
region
).items():
resource_client = ResourceManagementClient(
credentials, subscription_id, logging_enable=False # type: ignore
logging.info(
f"User {username}: Cycling instances under subscription {subscription['SubscriptionId']}..."
)

monitor_client = MonitorManagementClient(
credential=credentials, subscription_id=subscription["SubscriptionId"] # type: ignore
)
autoscale_settings_operations = monitor_client.autoscale_settings

compute_client = ComputeManagementClient(
credential=credentials, subscription_id=subscription["SubscriptionId"] # type: ignore
)
response = compute_client.virtual_machine_scale_sets.list_all()

resource_group_prefixes = subscription.get("ResourceGroupPrefixes")
scale_sets = get_scale_sets_by_prefix(
scale_set_prefix, resource_group_prefixes, response
)

for scale_set in scale_sets:
old_autoscale_setting_resource, min_instances = scale_up_instances(
username, autoscale_settings_operations, scale_set["ResourceGroup"]
)

if old_autoscale_setting_resource is None and min_instances is None:
logging.info(
f"User {username}: No changes made to scale set {scale_set['Name']}. Skipping..."
)
compute_client = ComputeManagementClient(credentials, subscription_id) # type: ignore

resource_groups = resource_client.resource_groups.list()
matching_resource_groups = [
rg
for rg in resource_groups
if rg.name.startswith(resource_group_name_prefix)
]
for matching_resource_group in matching_resource_groups:
resource_group_name = matching_resource_group.name
resources = resource_client.resources.list_by_resource_group(
resource_group_name
)
for resource in resources:
if resource.type == "Microsoft.Compute/virtualMachines":
try:
result = compute_client.virtual_machines.get(
resource_group_name,
resource.name,
expand="instanceView",
)
if (
len(result.instance_view.statuses) > 1
and "running"
in result.instance_view.statuses[1].display_status
and result.instance_view.computer_name
):
to_rotate.append(result.instance_view.computer_name)
else:
logging.warning(
f"User {username}: {resource.name} Not in RUNNING state - skipping"
)
except msrestazure.azure_exceptions.CloudError as e:
if "not found" in e.message:
logging.warning(
f"User {username}: {resource.name} Not Found - skipping"
)

logging.info(f"User {username}: Found the following VMs: {to_rotate}")
# Build dns names
for vm in to_rotate:
markers = []
commands = []
key_args["hostname"] = key_args.get("f_host").replace("<SERVER>", vm)
logging.info(f"User {username}: Writing key to " + key_args["hostname"])
for pkey in key_args.get("pkeys"):
if region.replace("-", "") in pkey:
key_args["pkey"] = pkey
for marker in key_args.get("fadmin_markers_commands"):
markers.append(marker)
commands.append(key_args.get("fadmin_markers_commands").get(marker))
key_args["commands"] = commands
key_args["markers"] = markers
ssh_server_command(config_map, username, **key_args)
continue

poll_until_success(
username, compute_client, scale_set["ResourceGroup"], min_instances
)

revert_changes(
username,
autoscale_settings_operations,
scale_set["ResourceGroup"],
old_autoscale_setting_resource,
)


def set_key_vault(config_map, username, **key_args):
Expand Down
12 changes: 6 additions & 6 deletions project/plugins/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ def wait_for_operation(username, compute, project, region, operation):
def rotate_instance_groups(config_map, username, **key_args):
auth = config_map["Global"]["google_credentials"]["client_cred"]
regions = key_args.get("regions")
rotate_gcp_instance_group(username, auth, regions, 2)
rotate_gcp_instance_group(username, auth, regions)


def rotate_fg_instance_groups(config_map, username, **key_args):
auth = config_map["Global"]["google_credentials"]["fg_cred"]
regions = key_args.get("regions")
rotate_gcp_instance_group(username, auth, regions, 1)
rotate_gcp_instance_group(username, auth, regions)


def rotate_gcp_instance_group(username, auth, regions, max_unavailable):
def rotate_gcp_instance_group(username, auth, regions, max_unavailable=2):
credentials = service_account.Credentials.from_service_account_file(auth)
# authenticate with compute api
try:
Expand Down Expand Up @@ -151,10 +151,10 @@ def rotate_gcp_instance_group(username, auth, regions, max_unavailable):
"updatePolicy": {
"minimalAction": "REPLACE",
"type": "PROACTIVE",
"maxSurge": {"fixed": 0},
"maxUnavailable": {"fixed": max_unavailable},
"maxSurge": {"fixed": 3},
"maxUnavailable": {"fixed": 0},
"minReadySec": 300,
"replacementMethod": "recreate",
"replacementMethod": "substitute",
},
"versions": [{"instanceTemplate": instance_template, "name": version}],
}
Expand Down
Loading