Verified Commit 12f4ba0c authored by Dmitry Volodin's avatar Dmitry Volodin
Browse files

Unified loader interface

parent 34fa3968
......@@ -95,7 +95,7 @@ class Command(BaseCommand):
model = self.MODELS.get(datastream)
if not model:
self.die("Unsupported datastream")
ds = loader.get_datastream(datastream)
ds = loader[datastream]
if not ds:
self.die("Cannot initialize datastream")
total = self.get_total(model)
......@@ -112,7 +112,7 @@ class Command(BaseCommand):
def handle_get(self, datastream, objects, filter, *args, **kwargs):
if not datastream:
self.die("--datastream is not set. Set one from list: %s" % self.MODELS.keys())
ds = loader.get_datastream(datastream)
ds = loader[datastream]
if not ds:
self.die("Cannot initialize datastream")
filter = filter or []
......
......@@ -94,7 +94,7 @@ class Command(BaseCommand):
for name in ds_loader.iter_datastreams():
if not getattr(config.datastream, "enable_%s" % name, False):
continue
ds = ds_loader.get_datastream(name)
ds = ds_loader[name]
self.print("[%s] Indexing datastream" % ds.name)
ds.ensure_collection()
......
......@@ -8,22 +8,20 @@
# Python modules
from __future__ import absolute_import
import os
import logging
# NOC modules
from noc.config import config
from .loader import loader
logger = logging.getLogger(__name__)
BIMODELS_PREFIX = os.path.join("bi", "models")
def ensure_bi_models(connect=None):
logger.info("Ensuring BI models:")
# Ensure fields
changed = False
for name in loader.iter_models():
model = loader.get_model(name)
for name in loader:
model = loader[name]
if not model:
continue
logger.info("Ensure table %s" % model._meta.db_table)
......
......@@ -8,80 +8,21 @@
# Python modules
from __future__ import absolute_import
import os
import threading
# NOC modules
from noc.config import config
from noc.core.loader.base import BaseLoader
from .model import Model
class ModelLoader(BaseLoader):
name = "bi"
IGNORED_MODELS = {"dashboard", "dashboardlayout"}
def __init__(self):
super(ModelLoader, self).__init__()
self.models = {} # Load models
self.lock = threading.Lock()
self.all_models = set()
def get_model(self, name):
with self.lock:
model = self.models.get(name)
if not model:
self.logger.info("Loading loader %s", name)
if not self.is_valid_name(name):
self.logger.error("Invalid loader name: %s", name)
return None
for p in config.get_customized_paths("", prefer_custom=True):
path = os.path.join(p, "bi", "models", "%s.py" % name)
if not os.path.exists(path):
continue
if p:
# Customized model
base_name = os.path.basename(os.path.dirname(p))
module_name = "%s.bi.models.%s" % (base_name, name)
else:
# Common model
module_name = "noc.bi.models.%s" % name
model = self.find_class(module_name, Model, name)
if model:
if not hasattr(model, "_meta"):
self.logger.error("Model %s has no _meta", name)
continue
if getattr(model._meta, "db_table", None) != name:
self.logger.error("Table name mismatch")
continue
break
if not model:
self.logger.error("Model not found: %s", name)
self.models[name] = model
return model
def is_valid_name(self, name):
return ".." not in name
def iter_models(self):
with self.lock:
if not self.all_models:
self.all_models = self.find_models()
for ds in sorted(self.all_models):
yield ds
def find_models(self):
"""
Scan all available models
"""
names = set()
for dn in config.get_customized_paths(os.path.join("bi", "models")):
for file in os.listdir(dn):
if file.startswith("_") or not file.endswith(".py"):
continue
name = file[:-3]
if name not in self.IGNORED_MODELS:
names.add(file[:-3])
return names
base_cls = Model
base_path = ("bi", "models")
ignored_names = {"dashboard", "dashboardlayout"}
def is_valid_class(self, kls, name):
if not hasattr(kls, "_meta"):
return False
return getattr(kls._meta, "db_table", None) == name
# Create singleton object
......
......@@ -74,7 +74,7 @@ def update_object(ds_name, object_id):
:param object_id:
:return:
"""
ds = loader.get_datastream(ds_name)
ds = loader[ds_name]
if not ds:
return
r = ds.update_object(object_id)
......
......@@ -8,76 +8,15 @@
# Python modules
from __future__ import absolute_import
import logging
import threading
import os
# NOC modules
from noc.config import config
from noc.core.loader.base import BaseLoader
from .base import DataStream
logger = logging.getLogger(__name__)
class DataStreamLoader(BaseLoader):
def __init__(self):
super(DataStreamLoader, self).__init__()
self.datastreams = {} # Load datastreams
self.lock = threading.Lock()
self.all_datastreams = set()
def get_datastream(self, name):
"""
Load datastream and return DataStream instance.
Returns None when no datastream found or loading error occured
"""
with self.lock:
datastream = self.datastreams.get(name)
if not datastream:
logger.info("Loading datastream %s", name)
if not self.is_valid_name(name):
logger.error("Invalid datastream name")
return None
for p in config.get_customized_paths("", prefer_custom=True):
path = os.path.join(p, "services", "datastream", "streams", "%s.py" % name)
if not os.path.exists(path):
continue
if p:
# Customized datastream
base_name = os.path.basename(os.path.dirname(p))
module_name = "%s.services.datastream.streams.%s" % (base_name, name)
else:
# Common datastream
module_name = "noc.services.datastream.streams.%s" % name
datastream = self.find_class(module_name, DataStream, name)
if datastream:
break
if not datastream:
logger.error("DataStream not found: %s", name)
self.datastreams[name] = datastream
return datastream
def is_valid_name(self, name):
return ".." not in name
def iter_datastreams(self):
with self.lock:
if not self.all_datastreams:
self.all_datastreams = self.find_datastreams()
for ds in sorted(self.all_datastreams):
yield ds
def find_datastreams(self):
"""
Scan all available datastreams
"""
names = set()
for dn in config.get_customized_paths(os.path.join("services", "datastream", "streams")):
for file in os.listdir(dn):
if file.startswith("_") or not file.endswith(".py"):
continue
names.add(file[:-3])
return names
name = "datastream"
base_cls = DataStream
base_path = ("services", "datastream", "streams")
# Create singleton object
......
......@@ -9,7 +9,10 @@
# Python modules
import logging
import inspect
import threading
import os
# NOC modules
from noc.config import config
from noc.core.log import PrefixLoggerAdapter
logger = logging.getLogger(__name__)
......@@ -17,9 +20,15 @@ logger = logging.getLogger(__name__)
class BaseLoader(object):
name = None
base_cls = None # Base class to be loaded
base_path = None # Tuple of path components
ignored_names = set()
def __init__(self):
self.logger = PrefixLoggerAdapter(logger, self.name)
self.classes = {}
self.lock = threading.Lock()
self.all_classes = set()
def find_class(self, module_name, base_cls, name):
"""
......@@ -37,9 +46,87 @@ class BaseLoader(object):
if (
inspect.isclass(o) and
issubclass(o, base_cls) and
o.__module__ == sm.__name__
o.__module__ == sm.__name__ and
self.is_valid_class(o, name)
):
return o
except ImportError as e:
self.logger.error("Failed to load %s %s: %s", self.name, name, e)
return None
def is_valid_class(self, kls, name):
"""
Check `find_class` found valid class
:param kls: Class
:param name: Class' name
:return: True if class is valid and should be returned
"""
return True
def is_valid_name(self, name):
return ".." not in name
def get_path(self, base, name):
"""
Get file path
:param base: "" or custom prefix
:param name: class name
:return:
"""
p = (base,) + self.base_path + ("%s.py" % name,)
return os.path.join(*p)
def get_module_name(self, base, name):
"""
Get module name
:param base: `noc` or custom prefix
:param name: module name
:return:
"""
return "%s.%s.%s" % (base, ".".join(self.base_path), name)
def get_class(self, name):
with self.lock:
kls = self.classes.get(name)
if not kls:
self.logger.info("Loading %s", name)
if not self.is_valid_name(name):
self.logger.error("Invalid name: %s", name)
return None
for p in config.get_customized_paths("", prefer_custom=True):
path = self.get_path(p, name)
if not os.path.exists(path):
continue
base_name = os.path.basename(os.path.dirname(p)) if p else "noc"
module_name = self.get_module_name(base_name, name)
kls = self.find_class(module_name, self.base_cls, name)
if kls:
break
if not kls:
logger.error("DataStream not found: %s", name)
self.classes[name] = kls
return kls
def __getitem__(self, item):
return self.get_class(item)
def __iter__(self):
return self.iter_classes()
def iter_classes(self):
with self.lock:
if not self.all_classes:
self.all_classes = self.find_classes()
for ds in sorted(self.all_classes):
yield ds
def find_classes(self):
names = set()
for dn in config.get_customized_paths(os.path.join(*self.base_path)):
for fn in os.listdir(dn):
if fn.startswith("_") or not fn.endswith(".py"):
continue
name = fn[:-3]
if name not in self.ignored_names:
names.add(name)
return names
......@@ -128,8 +128,8 @@ class BIAPI(API):
@classmethod
def get_bi_datasources(cls):
result = []
for mn in loader.iter_models():
model = loader.get_model(mn)
for mn in loader:
model = loader[mn]
if not model:
continue
r = {
......@@ -167,7 +167,7 @@ class BIAPI(API):
lock=lambda _: model_lock)
def get_model(cls, name):
# Static datasource
model = loader.get_model(name)
model = loader[name]
if model:
return model
# Dynamic datasource
......@@ -388,7 +388,7 @@ class BIAPI(API):
if "field_name" not in params:
metrics["error", ("type", "get_hierarchy_no_field_name")] += 1
raise APIError("No field name")
model = loader.get_model(params["datasource"])
model = loader[params["datasource"]]
if not model:
metrics["error", ("type", "get_hierarchy_invalid_datasource")] += 1
raise APIError("Invalid datasource")
......
......@@ -36,10 +36,10 @@ class DataStreamService(Service):
def get_datastreams(self):
r = []
for name in loader.iter_datastreams():
for name in loader:
if not getattr(config.datastream, "enable_%s" % name, False):
continue
ds = loader.get_datastream(name)
ds = loader[name]
if ds:
self.logger.info("[%s] Initializing datastream", name)
r += [ds]
......@@ -90,8 +90,8 @@ class DataStreamService(Service):
:return: True if .watch() is working
"""
# Get one datastream collection
dsn = next(loader.iter_datastreams())
ds = loader.get_datastream(dsn)
dsn = next(loader)
ds = loader[dsn]
coll = ds.get_collection()
# Check pymongo has .watch
if not hasattr(coll, "watch"):
......
......@@ -266,31 +266,30 @@ def test_datastream_clean_id_int():
DS.clean_id("z")
@pytest.fixture(params=["managedobject", "administrativedomain"])
@pytest.fixture(params=list(loader))
def datastream_name(request):
return request.param
def test_loader(datastream_name):
ds = loader.get_datastream(datastream_name)
ds = loader[datastream_name]
assert ds is not None
assert issubclass(ds, DataStream)
assert ds.name == datastream_name
def test_loader_invalid_name():
ds = loader.get_datastream("aaa..bbbb")
ds = loader["aaa..bbbb"]
assert ds is None
def test_loader_error():
ds = loader.get_datastream("invalid")
ds = loader["invalid"]
assert ds is None
def test_loader_iter_datastreams(datastream_name):
dses = set(loader.iter_datastreams())
assert datastream_name in dses
def test_loader_contains(datastream_name):
assert datastream_name in loader
def test_wait():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment