Skip to content

Commit 775d58c

Browse files
committed
Add numpy.linspace to openvino backend
1 parent 6e688ab commit 775d58c

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ NumpyDtypeTest::test_isclose
3232
NumpyDtypeTest::test_isfinite
3333
NumpyDtypeTest::test_isinf
3434
NumpyDtypeTest::test_isnan
35-
NumpyDtypeTest::test_linspace
3635
NumpyDtypeTest::test_log10
3736
NumpyDtypeTest::test_log1p
3837
NumpyDtypeTest::test_log
@@ -158,7 +157,6 @@ NumpyTwoInputOpsCorrectnessTest::test_divide_no_nan
158157
NumpyTwoInputOpsCorrectnessTest::test_einsum
159158
NumpyTwoInputOpsCorrectnessTest::test_inner
160159
NumpyTwoInputOpsCorrectnessTest::test_isclose
161-
NumpyTwoInputOpsCorrectnessTest::test_linspace
162160
NumpyTwoInputOpsCorrectnessTest::test_logspace
163161
NumpyTwoInputOpsCorrectnessTest::test_outer
164162
NumpyTwoInputOpsCorrectnessTest::test_quantile

keras/src/backend/openvino/numpy.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -823,11 +823,51 @@ def less_equal(x1, x2):
823823

824824

825825
def linspace(
826-
start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
826+
start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
827827
):
828-
raise NotImplementedError(
829-
"`linspace` is not supported with openvino backend"
830-
)
828+
start = get_ov_output(start)
829+
stop = get_ov_output(stop)
830+
831+
if dtype is not None:
832+
ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]
833+
else:
834+
ov_type = OPENVINO_DTYPES[config.floatx()]
835+
836+
if num < 0:
837+
raise ValueError(f"Number of samples, {num}, must be non-negative.")
838+
elif num == 0:
839+
range_vals = ov_opset.constant([], dtype=ov_type).output(0)
840+
if retstep:
841+
return OpenVINOKerasTensor(range_vals), None
842+
return OpenVINOKerasTensor(range_vals)
843+
elif num == 1:
844+
range_vals = ov_opset.broadcast(
845+
start, ov_opset.constant([1], dtype=Type.i32).output(0)
846+
).output(0)
847+
if retstep:
848+
return OpenVINOKerasTensor(range_vals), None
849+
return OpenVINOKerasTensor(range_vals)
850+
851+
num = ov_opset.constant(num, dtype=Type.i32).output(0)
852+
num = ov_opset.convert(num, ov_type)
853+
854+
if not endpoint:
855+
step = ov_opset.divide(ov_opset.subtract(stop, start), num).output(0)
856+
else:
857+
step = ov_opset.divide(
858+
ov_opset.subtract(stop, start),
859+
ov_opset.subtract(num,
860+
ov_opset.constant(1, dtype=ov_type).output(0))
861+
).output(0)
862+
863+
start = ov_opset.broadcast(start,
864+
ov_opset.shape_of(stop, Type.i32).output(0)
865+
).output(0)
866+
range_vals = ov_opset.range(start, stop, step, ov_type).output(0)
867+
868+
if retstep:
869+
return OpenVINOKerasTensor(range_vals), step
870+
return OpenVINOKerasTensor(range_vals)
831871

832872

833873
def log(x):

0 commit comments

Comments
 (0)