Skip to content

Added logic for storing Armis Access Token into Keyvault and Added Severity parameter into AlertActivity data connector #12193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
import json
from .sentinel import AzureSentinel
from .exports_store import ExportsTableStore
from Exceptions.ArmisExceptions import ArmisException, ArmisDataNotFoundException
from Exceptions.ArmisExceptions import ArmisException, ArmisDataNotFoundException, ArmisTimeOutException
from .utils import Utils
from . import consts
import inspect
import time


class ArmisAlertsActivities(Utils):
"""This class will process the Alert Activity data and post it into the Microsoft sentinel."""

def __init__(self):
def __init__(self, start_time):
"""__init__ method will initialize object of class."""
super().__init__()
self.start_time = start_time
self.data_alert_from = 0
self.azuresentinel = AzureSentinel()
self.total_alerts_posted = 0
Expand Down Expand Up @@ -174,6 +176,36 @@ def post_alert_activity_data(
)
raise ArmisException()

def process_large_chunks_of_activity_data(self, activity_uuids):
"""Process large chunks of activity data for specific alert.

Args:
activity_uuids (list): list of activity uuids
"""
__method_name = inspect.currentframe().f_code.co_name
try:
for index in range(0, len(activity_uuids), consts.CHUNK_SIZE):
if int(time.time()) >= self.start_time + consts.FUNCTION_APP_TIMEOUT_SECONDS:
raise ArmisTimeOutException()
chunk_of_activity_uuids = activity_uuids[index: index + consts.CHUNK_SIZE]
activity_data = self.get_activity_data(chunk_of_activity_uuids)
self.azuresentinel.post_data(
json.dumps(activity_data, indent=2),
consts.ARMIS_ACTIVITIES_TABLE,
"armis_activity_time",
)
self.total_activities_posted += len(activity_data)
logging.info(
consts.LOG_FORMAT.format(
__method_name, "Posted Activities count : {}.".format(len(activity_data))
)
)
except ArmisException:
raise ArmisException()

except ArmisTimeOutException:
raise ArmisTimeOutException()

def process_alerts_data(self, alerts, offset_to_post, checkpoint_table_object: ExportsTableStore):
"""Process alerts data to fetch related activity.

Expand All @@ -185,6 +217,8 @@ def process_alerts_data(self, alerts, offset_to_post, checkpoint_table_object: E
activity_uuid_list = []
alerts_data_to_post = []
for alert in alerts:
if int(time.time()) >= self.start_time + consts.FUNCTION_APP_TIMEOUT_SECONDS:
raise ArmisTimeOutException()
activity_uuids = alert.get("activityUUIDs", [])
if len(activity_uuid_list) + len(activity_uuids) <= consts.CHUNK_SIZE:
activity_uuid_list.extend(activity_uuids)
Expand All @@ -200,23 +234,13 @@ def process_alerts_data(self, alerts, offset_to_post, checkpoint_table_object: E
alerts_data_to_post.append(alert)
else:
logging.info(
consts.LOG_FORMAT.format(
__method_name, "Chunk size is greater than {}.".format(consts.CHUNK_SIZE))
)
for index in range(0, len(activity_uuids), consts.CHUNK_SIZE):
chunk_of_activity_uuids = activity_uuids[index: index + consts.CHUNK_SIZE]
activity_data = self.get_activity_data(chunk_of_activity_uuids)
self.azuresentinel.post_data(
json.dumps(activity_data, indent=2),
consts.ARMIS_ACTIVITIES_TABLE,
"armis_activity_time",
)
self.total_activities_posted += len(activity_data)
logging.info(
consts.LOG_FORMAT.format(
__method_name, "Posted Activities count : {}.".format(len(activity_data))
)
consts.LOG_FORMAT.format(
__method_name, "Chunk size is greater than {}.".format(consts.CHUNK_SIZE)
)
)
self.process_large_chunks_of_activity_data(
activity_uuids
)
self.azuresentinel.post_data(
json.dumps([alert], indent=2), consts.ARMIS_ALERTS_TABLE, "armis_alert_time"
)
Expand All @@ -237,6 +261,9 @@ def process_alerts_data(self, alerts, offset_to_post, checkpoint_table_object: E
except ArmisException:
raise ArmisException()

except ArmisTimeOutException:
raise ArmisTimeOutException()

except Exception as err:
logging.error(
consts.LOG_FORMAT.format(
Expand All @@ -257,13 +284,22 @@ def fetch_alert_data(
"""
__method_name = inspect.currentframe().f_code.co_name
try:
aql_with_severity = "in:alerts"
if consts.SEVERITY in consts.SEVERITIES:
severity_index = consts.SEVERITIES.index(consts.SEVERITY)
included_severities = ",".join(consts.SEVERITIES[severity_index:])
aql_with_severity = f"in:alerts severity:{included_severities}"
else:
raise ValueError()
if is_checkpoint_not_exist:
aql_data = "in:alerts"
aql_data = aql_with_severity
else:
aql_data = """{} after:{}""".format("in:alerts", last_time)
aql_data = """{} after:{}""".format(aql_with_severity, last_time)
alert_parameter["aql"] = aql_data
alert_parameter["length"] = 1000
while self.data_alert_from is not None:
if int(time.time()) >= self.start_time + consts.FUNCTION_APP_TIMEOUT_SECONDS:
raise ArmisTimeOutException()
alert_parameter.update({"from": self.data_alert_from})
offset_to_post = self.data_alert_from
logging.info(consts.LOG_FORMAT.format(__method_name, "Fetching alerts data with parameters = {}.".format(alert_parameter)))
Expand Down Expand Up @@ -295,6 +331,17 @@ def fetch_alert_data(
except ArmisDataNotFoundException:
raise ArmisDataNotFoundException()

except ArmisTimeOutException:
raise ArmisTimeOutException()

except ValueError:
logging.error(
consts.LOG_FORMAT.format(
__method_name, "Value Error occurred, Severity value is not from the list 'Low', 'Medium', 'High', 'Critical'."
)
)
raise ArmisException()

except Exception as err:
logging.error(consts.LOG_FORMAT.format(__method_name, "Error occurred : {}.".format(err)))
raise ArmisException()
Expand Down Expand Up @@ -374,6 +421,14 @@ def check_data_exists_or_not_alert(self):
except ArmisException:
raise ArmisException()

except ArmisTimeOutException:
logging.error(
consts.LOG_FORMAT.format(
__method_name, "9:30 mins executed hence stopping the execution"
)
)
return

except ArmisDataNotFoundException:
raise ArmisDataNotFoundException()

Expand All @@ -399,8 +454,8 @@ def main(mytimer: func.TimerRequest) -> None:
logging.info(
consts.LOG_FORMAT.format(__method_name, "Python timer trigger function ran at {}".format(utc_timestamp))
)

armis_obj = ArmisAlertsActivities()
start_time = time.time()
armis_obj = ArmisAlertsActivities(start_time)
try:
armis_obj.check_data_exists_or_not_alert()
except ArmisDataNotFoundException:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@
"activityUUIDs",
]
RETRY_COUNT_401 = 3
SEVERITY = os.environ.get("Severity", "Low")
SEVERITIES = ["Low", "Medium", "High", "Critical"]

# Sentinel constants
CONNECTION_STRING = os.environ.get("AzureWebJobsStorage", "")
ARMIS_ALERTS_TABLE = os.environ.get("ArmisAlertsTableName", "")
ARMIS_ACTIVITIES_TABLE = os.environ.get("ArmisActivitiesTableName", "")
WORKSPACE_ID = os.environ.get("WorkspaceID", "")
WORKSPACE_KEY = os.environ.get("WorkspaceKey", "")
KEYVAULT_NAME = os.environ.get("KeyVaultName", "")
CHUNK_SIZE = 35
FILE_SHARE = "funcstatemarkershare"
CHECKPOINT_FILE_TIME = "funcarmisalertsfile"
CHECKPOINT_FILE_OFFSET = "armisalertoffset"
LOG_FORMAT = "Armis Alerts Activities Connector: (method = {}) : {}"
REQUEST_TIMEOUT = 300
CHECKPOINT_TABLE_NAME = "ArmisAlertActivityCheckpoint"
FUNCTION_APP_TIMEOUT_SECONDS = 570
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"version": "2.0",
"extensionBundle": {
"id": "Microsoft.Azure.Functions.ExtensionBundle",
"version": "[4.*, 5.0.0)"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""This file is used for accessing keyvault to get or set secrets."""
import logging
from azure.keyvault.secrets import SecretClient
from azure.identity import DefaultAzureCredential
from azure.core.exceptions import ResourceNotFoundError
from . import consts


class KeyVaultSecretManager:
"""This class contains methods to authenticate with Azure KeyVault and get or set secrets in keyvault."""

def __init__(self) -> None:
"""Intialize instance variables for class."""
self.keyvault_name = consts.KEYVAULT_NAME
self.keyvault_uri = "https://{}.vault.azure.net/".format(self.keyvault_name)
self.client = self.get_client()

def get_client(self):
"""To obtain AzureKeyVault client.

Returns:
SecretClient: returns client object for accessing AzureKeyVault.
"""
credential = DefaultAzureCredential()
client = SecretClient(vault_url=self.keyvault_uri, credential=credential)
return client

def get_keyvault_secret(self, secret_name):
"""To get value of provided secretname from AzureKeyVault.

Args:
secret_name (str): secret name to get its value.
"""
try:
logging.info("Retrieving secret {} from {}.".format(secret_name, self.keyvault_name))
retrieved_secret = self.client.get_secret(secret_name)
logging.info("Retrieved secret value for {}.".format(retrieved_secret.name))
return retrieved_secret.value

except ResourceNotFoundError as err:
logging.error("Resource not found : '{}' ".format(err))
self.set_keyvault_secret(secret_name, "")
return ""

def set_keyvault_secret(self, secret_name, secret_value):
"""To update secret value of given secret name or create new secret.

Args:
secret_name (str): secret name to update its value or create it.
secret_value (str): secret value to be set as value of given secret name.
"""
logging.info("Creating or updating a secret '{}'.".format(secret_name))
self.client.set_secret(secret_name, secret_value)
logging.info("Secret created successfully : '{}' .".format(secret_name))

def get_properties_list_of_secrets(self):
"""To get list of secrets stored in keyvault with its properties.

Returns:
list: _description_
"""
secret_properties = self.client.list_properties_of_secrets()
properties_list = [secret_property.name for secret_property in secret_properties]
return properties_list
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import requests
from . import consts
from .state_manager import StateManager
from .keyvault_secrets_management import KeyVaultSecretManager


class Utils:
Expand All @@ -27,7 +28,14 @@ def __init__(self) -> None:
]
)
self._secret_key = consts.API_KEY
self.get_access_token()
self.keyvault_obj = KeyVaultSecretManager()
self.access_token_key = "armis-access-token"
properties_list = self.keyvault_obj.get_properties_list_of_secrets()
if self.access_token_key in properties_list:
self.access_token = self.keyvault_obj.get_keyvault_secret(self.access_token_key)
self.header.update({"Authorization": self.access_token})
else:
self.get_access_token()
self.state_manager_obj = StateManager(
connection_string=consts.CONNECTION_STRING, file_path=consts.CHECKPOINT_FILE_TIME
)
Expand Down Expand Up @@ -61,6 +69,35 @@ def check_environment_var_exist(self, environment_var):
)
raise ArmisException()

def compare_access_token(self):
"""compare_access_token will compare the current access token with the access token stored in keyvault
and update the header for further use.
"""
__method_name = inspect.currentframe().f_code.co_name
try:
keyvault_access_token = self.keyvault_obj.get_keyvault_secret(self.access_token_key)
header_access_token = self.header.get("Authorization")
if keyvault_access_token == header_access_token:
logging.info(consts.LOG_FORMAT.format(
__method_name, "KeyVault Access Token Invalid. Generating New Token."
))
self.get_access_token()
else:
logging.info(consts.LOG_FORMAT.format(
__method_name, "KeyVault Access Token Updated. Updating Header Value."
))
self.header.update({"Authorization": keyvault_access_token})
except ArmisException:
raise ArmisException()
except Exception as err:
logging.error(
consts.LOG_FORMAT.format(
__method_name,
"Unexpected error : {}.".format(err),
)
)
raise ArmisException()

def make_rest_call(self, method, url, params=None, headers=None, data=None, retry_401=0):
"""Make a rest call.

Expand Down Expand Up @@ -103,7 +140,7 @@ def make_rest_call(self, method, url, params=None, headers=None, data=None, retr
__method_name, "Unauthorized, Status code : {}, Retrying...".format(response.status_code)
)
)
self.get_access_token()
self.compare_access_token()
self.retry_count += 1
continue
elif response.status_code == 429:
Expand Down Expand Up @@ -230,6 +267,7 @@ def get_access_token(self):
response = self.make_rest_call(method="POST", url=consts.URL + consts.ACCESS_TOKEN_SUFFIX, data=body)
access_token = response.get("data", {}).get("access_token")
self.header.update({"Authorization": access_token})
self.keyvault_obj.set_keyvault_secret(self.access_token_key, access_token)
logging.info(consts.LOG_FORMAT.format(__method_name, "Generated access token Successfully."))
except KeyError as err:
logging.error(consts.LOG_FORMAT.format(__method_name, "Key error : {}.".format(err)))
Expand Down
Binary file not shown.
Loading
Loading