Skip to content

Commit ce2dd31

Browse files
committed
Add AM GPU param
1 parent 4a57e1b commit ce2dd31

File tree

5 files changed

+25
-1
lines changed

5 files changed

+25
-1
lines changed

hadoop-plugin-test/expectedJobs/jobs1/jobs1_job24.job

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file generated from the Hadoop DSL. Do not edit by hand.
22
type=TensorFlowJob
33
dependencies=jobs1_job23
4+
am_gpus=1
45
am_memory=2048
56
am_vcores=1
67
archive=tensorflow-starter-kit.zip

hadoop-plugin-test/src/main/gradle/positive/jobs1.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ hadoop {
385385
executes 'path/to/python/script.py'
386386
amMemoryMB 2048
387387
amCores 1
388+
amGpus 1
388389
psMemoryMB 2048
389390
psCores 1
390391
workerMemoryMB 8192

hadoop-plugin/src/main/groovy/com/linkedin/gradle/hadoopdsl/job/TensorFlowJob.groovy

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ interface TensorFlowJob {
5757
@HadoopDslMethod
5858
void amCores(int amCores);
5959

60+
/**
61+
* Sets number of GPUs for TensorFlow AM.
62+
*
63+
* @param amGpus Number of GPUs for TensorFlow AM
64+
*/
65+
@HadoopDslMethod
66+
void amGpus(int amGpus);
67+
6068
/**
6169
* Sets memory in MB for each parameter server's YARN container.
6270
* Leaving this unset will default to the Hadoop application's default value.

hadoop-plugin/src/main/groovy/com/linkedin/gradle/hadoopdsl/job/TensorFlowSparkJob.groovy

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ class TensorFlowSparkJob extends SparkJob implements TensorFlowJob {
9090
super.setConf("spark.driver.cores", amCores);
9191
}
9292

93+
@Override
94+
void amGpus(int amGpus) {
95+
throw new Exception("Requesting AM GPUs via tensorflow on spark is not supported currently.");
96+
}
97+
9398
@Override
9499
void psMemoryMB(int psMemoryMB) {
95100
// If worker memory has already been set, ignore
@@ -120,7 +125,7 @@ class TensorFlowSparkJob extends SparkJob implements TensorFlowJob {
120125

121126
@Override
122127
void workerGpus(int workerGpus) {
123-
throw new Exception("Requesting GPUs via tensorflow on spark is not supported currently.");
128+
throw new Exception("Requesting worker GPUs via tensorflow on spark is not supported currently.");
124129
}
125130

126131
@Override

hadoop-plugin/src/main/groovy/com/linkedin/gradle/hadoopdsl/job/TensorFlowTonyJob.groovy

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import com.linkedin.gradle.hadoopdsl.HadoopDslMethod;
3838
* executes path/to/python/script.py
3939
* amMemoryMB 2048
4040
* amCores 1
41+
* amGpus 1
4142
* psMemoryMB 2048
4243
* psCores 1
4344
* workerMemoryMB 8192
@@ -57,6 +58,7 @@ class TensorFlowTonyJob extends HadoopJavaProcessJob implements TensorFlowJob {
5758
String executePath;
5859
int amMemory;
5960
int amCores;
61+
int amGpus;
6062
int psMemory;
6163
int psCores;
6264
int workerMemory;
@@ -93,6 +95,7 @@ class TensorFlowTonyJob extends HadoopJavaProcessJob implements TensorFlowJob {
9395
cloneJob.executePath = executePath;
9496
cloneJob.amMemory = amMemory;
9597
cloneJob.amCores = amCores;
98+
cloneJob.amGpus = amGpus;
9699
cloneJob.psMemory = psMemory;
97100
cloneJob.psCores = psCores;
98101
cloneJob.workerMemory = workerMemory;
@@ -123,6 +126,12 @@ class TensorFlowTonyJob extends HadoopJavaProcessJob implements TensorFlowJob {
123126
setJobProperty("am_vcores", this.amCores);
124127
}
125128

129+
@Override
130+
void amGpus(int amGpus) {
131+
this.amGpus = amGpus;
132+
setJobProperty("am_gpus", this.amGpus);
133+
}
134+
126135
@Override
127136
void psMemoryMB(int psMemoryMB) {
128137
this.psMemory = psMemoryMB;

0 commit comments

Comments
 (0)