load(
"@rules_java//java:defs.bzl",
java_test_impl = "java_test"
)
load(
"@io_bazel_rules_scala//scala:scala.bzl",
scala_test_impl = "scala_test"
)
def junit_tests(name, **kwargs):
    """Declare a JVM test target, choosing java_test or scala_test by source type.

    The rule is picked from the extension of the first source file:
    ".java" creates a java_test (java_test_impl, loaded from @rules_java),
    ".scala" creates a scala_test (scala_test_impl, loaded from
    @io_bazel_rules_scala). The generated target is publicly visible.

    Args:
      name: name of the generated test target.
      **kwargs: only the keys below are read; any other key is ignored.
        srcs, sources: test source files. A non-empty "sources" takes
          precedence over "srcs".
        deps, dependencies: dependencies of the test. A non-empty
          "dependencies" takes precedence over "deps" and is normalized
          into labels with process_deps().
        extra_jvm_options: forwarded to the test rule as jvm_flags.

    Fails (via fail()) when no sources are given, or when the first source
    is neither a .java nor a .scala file.
    """
    sources = kwargs.get("sources") or kwargs.get("srcs")

    deps = kwargs.get("deps")
    if kwargs.get("dependencies"):
        deps = process_deps(kwargs.get("dependencies"))

    jvm_options = kwargs.get("extra_jvm_options")

    # The rule type is inferred from the sources, so at least one is required.
    if not sources:
        fail("junit_tests(name = '%s'): srcs/sources is empty; cannot decide between java_test and scala_test" % name)

    # endswith() is robust to dots elsewhere in the path (e.g.
    # "com.acme/FooTest.java") and to labels without any extension.
    first_src = sources[0]
    if first_src.endswith(".java"):
        # deps are forwarded so the Java test sources compile against them,
        # matching the Scala branch below.
        java_test_impl(
            name = name,
            srcs = sources,
            deps = deps,
            jvm_flags = jvm_options,
            visibility = ["//visibility:public"],
        )
    elif first_src.endswith(".scala"):
        scala_test_impl(
            name = name,
            srcs = sources,
            deps = deps,
            jvm_flags = jvm_options,
            visibility = ["//visibility:public"],
        )
    else:
        # Fail loudly instead of silently declaring no target at all.
        fail("junit_tests(name = '%s'): first source '%s' is neither a .java nor a .scala file" % (name, first_src))
def process_deps(dependencies):
    """Normalize dependency strings into Bazel labels.

    Entries that already look like labels are returned unchanged:
      - external-repository labels starting with "@" (e.g. "@maven//:x"),
      - absolute workspace labels starting with "//",
      - package-relative labels starting with ":".
    Any other entry is treated as a workspace-relative package path and is
    prefixed with "//".

    Args:
      dependencies: list of dependency strings, or None.

    Returns:
      A new list of label strings in the same order; [] for None or [].
    """
    if not dependencies:
        return []

    # Prefixing "//" onto something that is already a label would produce an
    # invalid label such as "////pkg:t" or "//@repo//x".
    return [
        d if d.startswith(("@", "//", ":")) else "//" + d
        for d in dependencies
    ]