Skip to content

Commit dc2a71c

Browse files
committed
last changes
1 parent 591cda1 commit dc2a71c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1773
-28
lines changed

examples/cifar_simple/catalyst.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ stages:
3535
accuracy_args: [1]
3636
scheduler:
3737
callback: SchedulerCallback
38-
reduce_metric: *reduce_metric
38+
reduced_metric: *reduce_metric
3939
saver:
4040
callback: CheckpointCallback
4141

mlcomp/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '20.3'
1+
__version__ = '20.3.1a'

mlcomp/db/models/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
from .docker import Docker
1111
from .model import Model
1212
from .auxilary import Auxiliary
13+
from .memory import Memory
14+
from .space import Space
1315

1416
__all__ = [
1517
'Project', 'Task', 'TaskDependence', 'File', 'DagStorage', 'DagLibrary',
1618
'Computer', 'ComputerUsage', 'Log', 'Step', 'Dag', 'ReportSeries',
1719
'ReportImg', 'ReportTasks', 'Report', 'ReportLayout', 'Docker', 'Model',
18-
'Auxiliary', 'TaskSynced'
20+
'Auxiliary', 'TaskSynced', 'Memory', 'Space'
1921
]

mlcomp/db/models/memory.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import sqlalchemy as sa
2+
3+
from mlcomp.db.models.base import Base
4+
5+
6+
class Memory(Base):
7+
__tablename__ = 'memory'
8+
9+
id = sa.Column(sa.Integer, primary_key=True)
10+
model = sa.Column(sa.String, nullable=False)
11+
variant = sa.Column(sa.String)
12+
num_classes = sa.Column(sa.Integer)
13+
img_size = sa.Column(sa.Integer)
14+
batch_size = sa.Column(sa.Integer, nullable=False)
15+
memory = sa.Column(sa.Float, nullable=False)
16+
17+
18+
__all__ = ['Memory']

mlcomp/db/models/space.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import sqlalchemy as sa
2+
from sqlalchemy import ForeignKey
3+
4+
from mlcomp.db.models.base import Base
5+
6+
7+
class Space(Base):
8+
__tablename__ = 'space'
9+
10+
name = sa.Column(sa.String, nullable=False, primary_key=True)
11+
created = sa.Column(sa.DateTime, nullable=False)
12+
changed = sa.Column(sa.DateTime, nullable=False)
13+
content = sa.Column(sa.String, nullable=False)
14+
15+
16+
class SpaceRelation(Base):
17+
__tablename__ = 'space_relation'
18+
19+
parent = sa.Column(sa.String, ForeignKey('space.name'),
20+
primary_key=True)
21+
child = sa.Column(sa.String, ForeignKey('space.name'),
22+
primary_key=True)
23+
24+
25+
__all__ = ['Space', 'SpaceRelation']

mlcomp/db/providers/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
from .model import ModelProvider
1717
from .auxiliary import AuxiliaryProvider
1818
from .task_synced import TaskSyncedProvider
19+
from .memory import MemoryProvider
20+
from .space import SpaceProvider
1921

2022
__all__ = [
2123
'ProjectProvider', 'TaskProvider', 'FileProvider', 'DagStorageProvider',
2224
'DagLibraryProvider', 'LogProvider', 'StepProvider', 'ComputerProvider',
2325
'DagProvider', 'ReportImgProvider', 'ReportProvider',
2426
'ReportLayoutProvider', 'ReportSeriesProvider', 'ReportTasksProvider',
2527
'DockerProvider', 'ModelProvider', 'AuxiliaryProvider',
26-
'TaskSyncedProvider'
28+
'TaskSyncedProvider', 'MemoryProvider', 'SpaceProvider'
2729
]

mlcomp/db/providers/base.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ def add(self, obj: Base, commit=True):
5151
return obj
5252

5353
def bulk_save_objects(
54-
self,
55-
objs,
56-
return_defaults=False,
57-
update_changed_only=True,
58-
preserve_order=True,
54+
self,
55+
objs,
56+
return_defaults=False,
57+
update_changed_only=True,
58+
preserve_order=True,
5959
):
6060
for obj in objs:
6161
adapt_db_types(obj)
@@ -68,8 +68,9 @@ def bulk_save_objects(
6868
)
6969
self._session.commit()
7070

71-
def by_id(self, id: int, joined_load=None):
72-
res = self.query(self.model).filter(getattr(self.model, 'id') == id)
71+
def by_id(self, id: int, joined_load=None, key_column: str = 'id'):
72+
res = self.query(self.model).filter(
73+
getattr(self.model, key_column) == id)
7374
if joined_load is not None:
7475
for n in joined_load:
7576
res = res.options(joinedload(n, innerjoin=True))

mlcomp/db/providers/memory.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from mlcomp.db.core import PaginatorOptions
2+
from mlcomp.db.models import Memory
3+
from mlcomp.db.providers.base import BaseDataProvider
4+
5+
6+
class MemoryProvider(BaseDataProvider):
7+
model = Memory
8+
9+
def get(self, filter: dict, options: PaginatorOptions = None):
10+
query = self.query(Memory)
11+
if filter.get('model'):
12+
query = query.filter(Memory.model.contains(filter['model']))
13+
if filter.get('variant'):
14+
query = query.filter(Memory.variant.contains(filter['variant']))
15+
16+
total = query.count()
17+
paginator = self.paginator(query, options) if options else query
18+
data = []
19+
for p in paginator.all():
20+
item = self.to_dict(p)
21+
data.append(item)
22+
23+
return {
24+
'total': total,
25+
'data': data
26+
}
27+
28+
def find(self, data: dict):
29+
query = self.query(Memory)
30+
for k, v in data.items():
31+
if k in ['batch_size']:
32+
continue
33+
34+
if hasattr(Memory, k):
35+
query = query.filter(getattr(Memory, k) == v)
36+
return query.all()
37+
38+
39+
__all__ = ['MemoryProvider']

mlcomp/db/providers/space.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from sqlalchemy import literal_column
2+
3+
from mlcomp.db.core import PaginatorOptions
4+
from mlcomp.db.models import Space
5+
from mlcomp.db.models.space import SpaceRelation
6+
from mlcomp.db.providers.base import BaseDataProvider
7+
8+
9+
class SpaceProvider(BaseDataProvider):
10+
model = Space
11+
12+
def get(self, filter: dict, options: PaginatorOptions = None):
13+
query = self.query(Space, literal_column('0').label('relation'))
14+
if 'parent' in filter:
15+
query = query.filter(Space.name != filter['parent'])
16+
17+
if filter.get('name'):
18+
query = query.filter(Space.name.contains(filter['name']))
19+
if filter.get('parent'):
20+
relation = literal_column('1').label('relation')
21+
query2 = self.query(Space, relation). \
22+
join(SpaceRelation, SpaceRelation.child == Space.name).filter(
23+
SpaceRelation.parent == filter['parent']
24+
)
25+
query2_names = [space.name for space, _ in query2.all()]
26+
query = query.filter(Space.name.notin_(query2_names))
27+
query = query.union_all(query2)
28+
query = query.order_by(relation.desc())
29+
30+
total = query.count()
31+
paginator = self.paginator(query, options) if options else query
32+
data = []
33+
for space, relation in paginator.all():
34+
item = self.to_dict(space)
35+
item['relation'] = relation
36+
data.append(item)
37+
38+
return {
39+
'total': total,
40+
'data': data
41+
}
42+
43+
def add_relation(self, parent: str, child: str):
44+
self.add(SpaceRelation(parent=parent, child=child))
45+
46+
def remove_relation(self, parent: str, child: str):
47+
self.query(SpaceRelation).filter(
48+
SpaceRelation.parent == parent).filter(
49+
SpaceRelation.child == child).delete(synchronize_session=False)
50+
51+
self.session.commit()
52+
53+
def related(self, parent: str):
54+
res = self.query(Space).join(SpaceRelation,
55+
SpaceRelation.child == Space.name) \
56+
.filter(SpaceRelation.parent == parent) \
57+
.all()
58+
return res
59+
60+
61+
__all__ = ['SpaceProvider']

mlcomp/db/providers/task.py

+8
Original file line numberDiff line numberDiff line change
@@ -315,5 +315,13 @@ def find_dependents(self, task_id: int):
315315
all()
316316
return res
317317

318+
def get_dependencies(self, dag: int):
319+
res = self.query(TaskDependence).join(
320+
Task, Task.id == TaskDependence.task_id).\
321+
filter(Task.dag == dag).\
322+
all()
323+
324+
return res
325+
318326

319327
__all__ = ['TaskProvider']

0 commit comments

Comments
 (0)