Skip to content

Commit d6f4301

Browse files
committed
[DLRM/TF2] Support TensorFlow 2.10
1 parent 72aebe7 commit d6f4301

2 files changed

Lines changed: 13 additions & 6 deletions

File tree

TensorFlow2/Recommendation/DLRM/Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ WORKDIR /dlrm
2222

2323
ADD requirements.txt .
2424

25-
RUN pip install -r requirements.txt
25+
RUN pip install --upgrade pip && pip install -r requirements.txt
2626

2727
RUN rm -rf distributed-embeddings &&\
2828
git clone https://github.com/NVIDIA-Merlin/distributed-embeddings.git &&\
2929
cd distributed-embeddings &&\
30-
git checkout 427f869ac &&\
30+
git checkout v0.2 &&\
31+
git submodule init && git submodule update &&\
3132
pip uninstall -y distributed-embeddings &&\
3233
make clean &&\
3334
make pip_pkg -j all &&\

TensorFlow2/Recommendation/DLRM/tensorflow-dot-based-interact/Makefile

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,14 @@ PYTHON_BIN_PATH = python
1919

2020
TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
2121
TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
22-
23-
CFLAGS = ${TF_CFLAGS} -fPIC -O2 -std=c++14
22+
TF_VERSION := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(int(tf.__version__.split(".")[1]))')
23+
ifeq ($(shell expr $(TF_VERSION) \>= 10), 1)
24+
CPP_STD := 17
25+
else
26+
CPP_STD := 14
27+
endif
28+
29+
CFLAGS = ${TF_CFLAGS} -fPIC -O2 -std=c++${CPP_STD}
2430
LDFLAGS = -shared ${TF_LFLAGS}
2531

2632
.DEFAULT_GOAL := lib
@@ -50,11 +56,11 @@ endif
5056

5157
volta: $(VOLTA_TARGET_OBJECT)
5258
$(VOLTA_TARGET_OBJECT): $(CC_SRC_DIR)/kernels/volta/dot_based_interact_volta.cu
53-
$(NVCC) -std=c++14 -c -o $@ $^ $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_70
59+
$(NVCC) -std=c++${CPP_STD} -c -o $@ $^ $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_70
5460

5561
ampere: $(AMPERE_TARGET_OBJECT)
5662
$(AMPERE_TARGET_OBJECT): $(CC_SRC_DIR)/kernels/ampere/dot_based_interact_ampere.cu
57-
$(NVCC) -std=c++14 -c -o $@ $^ $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_80
63+
$(NVCC) -std=c++${CPP_STD} -c -o $@ $^ $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_80
5864

5965
lib: $(TARGET_LIB)
6066
$(TARGET_LIB): $(CC_SRCS) $(VOLTA_TARGET_OBJECT) $(AMPERE_TARGET_OBJECT)

0 commit comments

Comments
 (0)