Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cosmo/clients/queries/device.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ query {

device_type {
__typename
manufacturer {
__typename
slug
}
slug
}
platform {
Expand Down
49 changes: 36 additions & 13 deletions cosmo/manufacturers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@


from cosmo.common import DeviceSerializationError
from cosmo.netbox_types import DeviceType, InterfaceType, PlatformType, VRFType
from cosmo.netbox_types import (
DeviceType,
InterfaceType,
PlatformType,
VRFType,
DeviceTypeType,
)


class AbstractManufacturer(ABC):
Expand All @@ -16,22 +22,39 @@ def __init__(self, cosmo_config: "CosmoConfig"):

@classmethod
def isCompatibleWith(cls, device: DeviceType):
# Note: If the platform cannot be parsed, getPlatform will be a string.
if not isinstance(device.getDeviceType(), DeviceTypeType):
return False

if not isinstance(device.getPlatform(), PlatformType):
return False

device_platform_match = False
if device.getPlatform().getManufacturer():
return (
device_platform_match = (
device.getPlatform().getManufacturer().getSlug()
== cls.myManufacturerSlug()
in cls.myManufacturerSlugs()
)
else:
device_platform_match = bool(
re.match(cls.myPlatformRE(), device.getPlatform().getSlug())
)

device_manufacturer_match = False
if device.getDeviceType().getManufacturer():
device_manufacturer_match = (
device.getDeviceType().getManufacturer().getSlug()
in cls.myManufacturerSlugs()
)
else:
# fallback in case no manufacturer is filled in for the platform
return re.match(cls.myPlatformRE(), device.getPlatform().getSlug())
device_manufacturer_match = bool(
re.match(cls.myPlatformRE(), device.getDeviceType().getSlug())
)

return device_platform_match or device_manufacturer_match

@staticmethod
@abstractmethod
def myManufacturerSlug():
def myManufacturerSlugs() -> list[str]:
pass

@classmethod
Expand Down Expand Up @@ -110,8 +133,8 @@ class JuniperManufacturer(AbstractJuniperRtBrickManufacturerCommon):
_platform_re = re.compile(r"REPLACEME")

@staticmethod
def myManufacturerSlug():
return "juniper"
def myManufacturerSlugs():
return ["juniper"]

@classmethod
def myPlatformRE(cls):
Expand All @@ -138,8 +161,8 @@ class RtBrickManufacturer(AbstractJuniperRtBrickManufacturerCommon):
_platform_re = re.compile(r"REPLACEME")

@staticmethod
def myManufacturerSlug():
return "rtbrick"
def myManufacturerSlugs():
return ["rtbrick", "ufispace", "edgecore"]

@classmethod
def myPlatformRE(cls):
Expand Down Expand Up @@ -167,8 +190,8 @@ class CumulusNetworksManufacturer(AbstractManufacturer):
_platform_re = re.compile(r"^cumulus-linux[a-zA-Z0-9-]*")

@staticmethod
def myManufacturerSlug():
return "cumulus-networks"
def myManufacturerSlugs():
return ["cumulus-networks"]

@classmethod
def myPlatformRE(cls):
Expand Down
5 changes: 4 additions & 1 deletion cosmo/netbox_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,16 @@ class DeviceTypeType(AbstractNetboxType):
def getBasePath(self):
return "/dcim/device-types/"

def getManufacturer(self):
return self.get("manufacturer")


class PlatformType(AbstractNetboxType):
def getBasePath(self):
return "/dcim/platforms/"

def getManufacturer(self):
return self["manufacturer"]
return self.get("manufacturer")


class ManufacturerType(AbstractNetboxType):
Expand Down
8 changes: 6 additions & 2 deletions cosmo/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_router_platforms(mock_cosmo_config_fixture, mock_global_vrf, mock_l3vpn
DeviceType(juniper_s.device), mock_cosmo_config_fixture
).get()
assert juniper_manufacturer.getManagementVRFName() == "mgmt_junos"
assert juniper_manufacturer.myManufacturerSlug() == "juniper"
assert juniper_manufacturer.myManufacturerSlugs() == ["juniper"]
assert (
juniper_manufacturer.spitVRFPathWith(mock_global_vrf, {})
== juniper_manufacturer._spitDefaultVRFPathWith({})
Expand Down Expand Up @@ -146,7 +146,11 @@ def test_router_platforms(mock_cosmo_config_fixture, mock_global_vrf, mock_l3vpn
DeviceType(rtbrick_s.device), mock_cosmo_config_fixture
).get()
assert rtbrick_manufacturer.getManagementVRFName() == "mgmt"
assert rtbrick_manufacturer.myManufacturerSlug() == "rtbrick"
assert rtbrick_manufacturer.myManufacturerSlugs() == [
"rtbrick",
"ufispace",
"edgecore",
]
assert (
rtbrick_manufacturer.spitVRFPathWith(mock_global_vrf, {})
== rtbrick_manufacturer._spitDefaultVRFPathWith({})
Expand Down
Loading