Skip to content

Commit f8cd70f

Browse files
authored
Replaced Env class with dataclass (#277)
1 parent 8fb12af commit f8cd70f

File tree

1 file changed

+48
-170
lines changed

1 file changed

+48
-170
lines changed

ml_service/util/env_variables.py

Lines changed: 48 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -1,174 +1,52 @@
1+
"""Env dataclass to load and hold all environment variables
2+
"""
3+
from dataclasses import dataclass
14
import os
2-
from dotenv import load_dotenv
3-
4-
5-
class Singleton(object):
6-
_instances = {}
7-
8-
def __new__(class_, *args, **kwargs):
9-
if class_ not in class_._instances:
10-
class_._instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs) # noqa E501
11-
return class_._instances[class_]
12-
13-
14-
class Env(Singleton):
15-
16-
def __init__(self):
17-
load_dotenv()
18-
self._workspace_name = os.environ.get("WORKSPACE_NAME")
19-
self._resource_group = os.environ.get("RESOURCE_GROUP")
20-
self._subscription_id = os.environ.get("SUBSCRIPTION_ID")
21-
self._tenant_id = os.environ.get("TENANT_ID")
22-
self._app_id = os.environ.get("SP_APP_ID")
23-
self._app_secret = os.environ.get("SP_APP_SECRET")
24-
self._vm_size = os.environ.get("AML_COMPUTE_CLUSTER_CPU_SKU")
25-
self._compute_name = os.environ.get("AML_COMPUTE_CLUSTER_NAME")
26-
self._vm_priority = os.environ.get("AML_CLUSTER_PRIORITY", 'lowpriority') # noqa E501
27-
self._min_nodes = int(os.environ.get("AML_CLUSTER_MIN_NODES", 0))
28-
self._max_nodes = int(os.environ.get("AML_CLUSTER_MAX_NODES", 4))
29-
self._build_id = os.environ.get("BUILD_BUILDID")
30-
self._pipeline_name = os.environ.get("TRAINING_PIPELINE_NAME")
31-
self._sources_directory_train = os.environ.get("SOURCES_DIR_TRAIN")
32-
self._train_script_path = os.environ.get("TRAIN_SCRIPT_PATH")
33-
self._evaluate_script_path = os.environ.get("EVALUATE_SCRIPT_PATH")
34-
self._register_script_path = os.environ.get("REGISTER_SCRIPT_PATH")
35-
self._model_name = os.environ.get("MODEL_NAME")
36-
self._experiment_name = os.environ.get("EXPERIMENT_NAME")
37-
self._model_version = os.environ.get('MODEL_VERSION')
38-
self._image_name = os.environ.get('IMAGE_NAME')
39-
self._db_cluster_id = os.environ.get("DB_CLUSTER_ID")
40-
self._score_script = os.environ.get("SCORE_SCRIPT")
41-
self._build_uri = os.environ.get("BUILD_URI")
42-
self._dataset_name = os.environ.get("DATASET_NAME")
43-
self._datastore_name = os.environ.get("DATASTORE_NAME")
44-
self._dataset_version = os.environ.get("DATASET_VERSION")
45-
self._run_evaluation = os.environ.get("RUN_EVALUATION", "true")
46-
self._allow_run_cancel = os.environ.get(
47-
"ALLOW_RUN_CANCEL", "true")
48-
self._aml_env_name = os.environ.get("AML_ENV_NAME")
49-
self._rebuild_env = os.environ.get("AML_REBUILD_ENVIRONMENT",
50-
"false").lower().strip() == "true"
51-
52-
@property
53-
def workspace_name(self):
54-
return self._workspace_name
55-
56-
@property
57-
def resource_group(self):
58-
return self._resource_group
59-
60-
@property
61-
def subscription_id(self):
62-
return self._subscription_id
63-
64-
@property
65-
def tenant_id(self):
66-
return self._tenant_id
67-
68-
@property
69-
def app_id(self):
70-
return self._app_id
71-
72-
@property
73-
def app_secret(self):
74-
return self._app_secret
75-
76-
@property
77-
def vm_size(self):
78-
return self._vm_size
79-
80-
@property
81-
def compute_name(self):
82-
return self._compute_name
83-
84-
@property
85-
def db_cluster_id(self):
86-
return self._db_cluster_id
87-
88-
@property
89-
def build_id(self):
90-
return self._build_id
91-
92-
@property
93-
def pipeline_name(self):
94-
return self._pipeline_name
5+
from typing import Optional
956

96-
@property
97-
def sources_directory_train(self):
98-
return self._sources_directory_train
99-
100-
@property
101-
def train_script_path(self):
102-
return self._train_script_path
103-
104-
@property
105-
def evaluate_script_path(self):
106-
return self._evaluate_script_path
107-
108-
@property
109-
def register_script_path(self):
110-
return self._register_script_path
111-
112-
@property
113-
def model_name(self):
114-
return self._model_name
115-
116-
@property
117-
def experiment_name(self):
118-
return self._experiment_name
119-
120-
@property
121-
def vm_priority(self):
122-
return self._vm_priority
123-
124-
@property
125-
def min_nodes(self):
126-
return self._min_nodes
127-
128-
@property
129-
def max_nodes(self):
130-
return self._max_nodes
131-
132-
@property
133-
def model_version(self):
134-
return self._model_version
135-
136-
@property
137-
def image_name(self):
138-
return self._image_name
139-
140-
@property
141-
def score_script(self):
142-
return self._score_script
143-
144-
@property
145-
def build_uri(self):
146-
return self._build_uri
147-
148-
@property
149-
def dataset_name(self):
150-
return self._dataset_name
151-
152-
@property
153-
def datastore_name(self):
154-
return self._datastore_name
155-
156-
@property
157-
def dataset_version(self):
158-
return self._dataset_version
159-
160-
@property
161-
def run_evaluation(self):
162-
return self._run_evaluation
163-
164-
@property
165-
def allow_run_cancel(self):
166-
return self._allow_run_cancel
7+
from dotenv import load_dotenv
1678

168-
@property
169-
def aml_env_name(self):
170-
return self._aml_env_name
1719

172-
@property
173-
def rebuild_env(self):
174-
return self._rebuild_env
10+
@dataclass(frozen=True)
11+
class Env:
12+
"""Loads all environment variables into a predefined set of properties
13+
"""
14+
# to load .env file into environment variables for local execution
15+
load_dotenv()
16+
workspace_name: Optional[str] = os.environ.get("WORKSPACE_NAME")
17+
resource_group: Optional[str] = os.environ.get("RESOURCE_GROUP")
18+
subscription_id: Optional[str] = os.environ.get("SUBSCRIPTION_ID")
19+
tenant_id: Optional[str] = os.environ.get("TENANT_ID")
20+
app_id: Optional[str] = os.environ.get("SP_APP_ID")
21+
app_secret: Optional[str] = os.environ.get("SP_APP_SECRET")
22+
vm_size: Optional[str] = os.environ.get("AML_COMPUTE_CLUSTER_CPU_SKU")
23+
compute_name: Optional[str] = os.environ.get("AML_COMPUTE_CLUSTER_NAME")
24+
vm_priority: Optional[str] = os.environ.get("AML_CLUSTER_PRIORITY",
25+
'lowpriority')
26+
min_nodes: int = int(os.environ.get("AML_CLUSTER_MIN_NODES", 0))
27+
max_nodes: int = int(os.environ.get("AML_CLUSTER_MAX_NODES", 4))
28+
build_id: Optional[str] = os.environ.get("BUILD_BUILDID")
29+
pipeline_name: Optional[str] = os.environ.get("TRAINING_PIPELINE_NAME")
30+
sources_directory_train: Optional[str] = os.environ.get(
31+
"SOURCES_DIR_TRAIN")
32+
train_script_path: Optional[str] = os.environ.get("TRAIN_SCRIPT_PATH")
33+
evaluate_script_path: Optional[str] = os.environ.get(
34+
"EVALUATE_SCRIPT_PATH")
35+
register_script_path: Optional[str] = os.environ.get(
36+
"REGISTER_SCRIPT_PATH")
37+
model_name: Optional[str] = os.environ.get("MODEL_NAME")
38+
experiment_name: Optional[str] = os.environ.get("EXPERIMENT_NAME")
39+
model_version: Optional[str] = os.environ.get('MODEL_VERSION')
40+
image_name: Optional[str] = os.environ.get('IMAGE_NAME')
41+
db_cluster_id: Optional[str] = os.environ.get("DB_CLUSTER_ID")
42+
score_script: Optional[str] = os.environ.get("SCORE_SCRIPT")
43+
build_uri: Optional[str] = os.environ.get("BUILD_URI")
44+
dataset_name: Optional[str] = os.environ.get("DATASET_NAME")
45+
datastore_name: Optional[str] = os.environ.get("DATASTORE_NAME")
46+
dataset_version: Optional[str] = os.environ.get("DATASET_VERSION")
47+
run_evaluation: Optional[str] = os.environ.get("RUN_EVALUATION", "true")
48+
allow_run_cancel: Optional[str] = os.environ.get("ALLOW_RUN_CANCEL",
49+
"true")
50+
aml_env_name: Optional[str] = os.environ.get("AML_ENV_NAME")
51+
rebuild_env: Optional[bool] = os.environ.get(
52+
"AML_REBUILD_ENVIRONMENT", "false").lower().strip() == "true"

0 commit comments

Comments
 (0)