Skip to content

Commit 7e46ae9

Browse files
CEL Dev Teamcopybara-github
authored andcommitted
Add replace function to CEL strings extension.
PiperOrigin-RevId: 688180534
1 parent bd7a08b commit 7e46ae9

3 files changed

Lines changed: 182 additions & 1 deletion

File tree

extensions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ cc_test(
365365
"//runtime:runtime_builder",
366366
"//runtime:runtime_options",
367367
"//runtime:standard_runtime_builder_factory",
368+
"@com_google_absl//absl/status:status_matchers",
368369
"@com_google_absl//absl/strings:cord",
369370
"@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto",
370371
],

extensions/strings.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,51 @@ absl::StatusOr<Value> LowerAscii(ValueManager& value_manager,
216216
return value_manager.CreateUncheckedStringValue(std::move(content));
217217
}
218218

219+
absl::StatusOr<Value> Replace2(ValueManager& value_manager,
220+
const StringValue& string,
221+
const StringValue& old_sub,
222+
const StringValue& new_sub, int64_t limit) {
223+
if (limit == 0) {
224+
// When the replacement limit is 0, the result is the original string.
225+
return string;
226+
}
227+
if (limit < 0) {
228+
// Per spec, when limit is negative treat is as unlimited.
229+
limit = std::numeric_limits<int64_t>::max();
230+
}
231+
232+
std::string result;
233+
std::string old_sub_scratch;
234+
absl::string_view old_sub_view = old_sub.NativeString(old_sub_scratch);
235+
std::string new_sub_scratch;
236+
absl::string_view new_sub_view = new_sub.NativeString(new_sub_scratch);
237+
std::string content_scratch;
238+
absl::string_view content_view = string.NativeString(content_scratch);
239+
while (limit > 0 && !content_view.empty()) {
240+
auto pos = content_view.find(old_sub_view);
241+
if (pos == absl::string_view::npos) {
242+
break;
243+
}
244+
result.append(content_view.substr(0, pos));
245+
result.append(new_sub_view);
246+
--limit;
247+
content_view.remove_prefix(pos + old_sub_view.size());
248+
}
249+
// Add the remainder of the string.
250+
if (!content_view.empty()) {
251+
result.append(content_view);
252+
}
253+
254+
return value_manager.CreateUncheckedStringValue(std::move(result));
255+
}
256+
257+
absl::StatusOr<Value> Replace1(ValueManager& value_manager,
258+
const StringValue& string,
259+
const StringValue& old_sub,
260+
const StringValue& new_sub) {
261+
return Replace2(value_manager, string, old_sub, new_sub, -1);
262+
}
263+
219264
} // namespace
220265

221266
absl::Status RegisterStringsFunctions(FunctionRegistry& registry,
@@ -246,6 +291,18 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry,
246291
CreateDescriptor("lowerAscii", /*receiver_style=*/true),
247292
UnaryFunctionAdapter<absl::StatusOr<Value>, StringValue>::WrapFunction(
248293
LowerAscii)));
294+
CEL_RETURN_IF_ERROR(registry.Register(
295+
VariadicFunctionAdapter<
296+
absl::StatusOr<Value>, StringValue, StringValue,
297+
StringValue>::CreateDescriptor("replace", /*receiver_style=*/true),
298+
VariadicFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue,
299+
StringValue>::WrapFunction(Replace1)));
300+
CEL_RETURN_IF_ERROR(registry.Register(
301+
VariadicFunctionAdapter<
302+
absl::StatusOr<Value>, StringValue, StringValue, StringValue,
303+
int64_t>::CreateDescriptor("replace", /*receiver_style=*/true),
304+
VariadicFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue,
305+
StringValue, int64_t>::WrapFunction(Replace2)));
249306
return absl::OkStatus();
250307
}
251308

extensions/strings_test.cc

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <utility>
1919

2020
#include "google/api/expr/v1alpha1/syntax.pb.h"
21+
#include "absl/status/status_matchers.h"
2122
#include "absl/strings/cord.h"
2223
#include "common/memory.h"
2324
#include "common/value.h"
@@ -36,6 +37,7 @@
3637
namespace cel::extensions {
3738
namespace {
3839

40+
using ::absl_testing::IsOk;
3941
using ::google::api::expr::v1alpha1::ParsedExpr;
4042
using ::google::api::expr::parser::Parse;
4143
using ::google::api::expr::parser::ParserOptions;
@@ -46,7 +48,8 @@ TEST(Strings, SplitWithEmptyDelimiterCord) {
4648
ASSERT_OK_AND_ASSIGN(auto builder,
4749
CreateStandardRuntimeBuilder(
4850
internal::GetTestingDescriptorPool(), options));
49-
EXPECT_OK(RegisterStringsFunctions(builder.function_registry(), options));
51+
EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options),
52+
IsOk());
5053

5154
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
5255

@@ -71,5 +74,125 @@ TEST(Strings, SplitWithEmptyDelimiterCord) {
7174
EXPECT_TRUE(result.GetBool().NativeValue());
7275
}
7376

77+
TEST(Strings, Replace) {
78+
MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting();
79+
const auto options = RuntimeOptions{};
80+
ASSERT_OK_AND_ASSIGN(auto builder,
81+
CreateStandardRuntimeBuilder(
82+
internal::GetTestingDescriptorPool(), options));
83+
EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options),
84+
IsOk());
85+
86+
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
87+
88+
ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
89+
Parse("foo.replace('he', 'we') == 'wello wello'",
90+
"<input>", ParserOptions{}));
91+
92+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
93+
ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));
94+
95+
common_internal::LegacyValueManager value_factory(memory_manager,
96+
runtime->GetTypeProvider());
97+
98+
Activation activation;
99+
activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")});
100+
101+
ASSERT_OK_AND_ASSIGN(Value result,
102+
program->Evaluate(activation, value_factory));
103+
ASSERT_TRUE(result.Is<BoolValue>());
104+
EXPECT_TRUE(result.GetBool().NativeValue());
105+
}
106+
107+
TEST(Strings, ReplaceWithNegativeLimit) {
108+
MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting();
109+
const auto options = RuntimeOptions{};
110+
ASSERT_OK_AND_ASSIGN(auto builder,
111+
CreateStandardRuntimeBuilder(
112+
internal::GetTestingDescriptorPool(), options));
113+
EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options),
114+
IsOk());
115+
116+
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
117+
118+
ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
119+
Parse("foo.replace('he', 'we', -1) == 'wello wello'",
120+
"<input>", ParserOptions{}));
121+
122+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
123+
ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));
124+
125+
common_internal::LegacyValueManager value_factory(memory_manager,
126+
runtime->GetTypeProvider());
127+
128+
Activation activation;
129+
activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")});
130+
131+
ASSERT_OK_AND_ASSIGN(Value result,
132+
program->Evaluate(activation, value_factory));
133+
ASSERT_TRUE(result.Is<BoolValue>());
134+
EXPECT_TRUE(result.GetBool().NativeValue());
135+
}
136+
137+
TEST(Strings, ReplaceWithLimit) {
138+
MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting();
139+
const auto options = RuntimeOptions{};
140+
ASSERT_OK_AND_ASSIGN(auto builder,
141+
CreateStandardRuntimeBuilder(
142+
internal::GetTestingDescriptorPool(), options));
143+
EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options),
144+
IsOk());
145+
146+
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
147+
148+
ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
149+
Parse("foo.replace('he', 'we', 1) == 'wello hello'",
150+
"<input>", ParserOptions{}));
151+
152+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
153+
ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));
154+
155+
common_internal::LegacyValueManager value_factory(memory_manager,
156+
runtime->GetTypeProvider());
157+
158+
Activation activation;
159+
activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")});
160+
161+
ASSERT_OK_AND_ASSIGN(Value result,
162+
program->Evaluate(activation, value_factory));
163+
ASSERT_TRUE(result.Is<BoolValue>());
164+
EXPECT_TRUE(result.GetBool().NativeValue());
165+
}
166+
167+
TEST(Strings, ReplaceWithZeroLimit) {
168+
MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting();
169+
const auto options = RuntimeOptions{};
170+
ASSERT_OK_AND_ASSIGN(auto builder,
171+
CreateStandardRuntimeBuilder(
172+
internal::GetTestingDescriptorPool(), options));
173+
EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options),
174+
IsOk());
175+
176+
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
177+
178+
ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
179+
Parse("foo.replace('he', 'we', 0) == 'hello hello'",
180+
"<input>", ParserOptions{}));
181+
182+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
183+
ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));
184+
185+
common_internal::LegacyValueManager value_factory(memory_manager,
186+
runtime->GetTypeProvider());
187+
188+
Activation activation;
189+
activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")});
190+
191+
ASSERT_OK_AND_ASSIGN(Value result,
192+
program->Evaluate(activation, value_factory));
193+
ASSERT_TRUE(result.Is<BoolValue>());
194+
EXPECT_TRUE(result.GetBool().NativeValue());
195+
}
196+
74197
} // namespace
75198
} // namespace cel::extensions

0 commit comments

Comments
 (0)