Skip to content

Commit 8f21ff9

Browse files
authored
[MLIR][DLTI][Transform] Introduce transform.dlti.query (#101561)
This transform op makes it possible to query attributes associated to IR by means of the DLTI dialect. The op takes both a `key` and a target `op` to perform the query at. Facility functions automatically find the closest ancestor op which defines the appropriate DLTI interface or has an attribute implementing a DLTI interface. By default the lookup uses the data layout interfaces of DLTI. If the optional `device` parameter is provided, the lookup happens with respect to the interfaces for TargetSystemSpec and TargetDeviceSpec. This op uses new free-standing functions in the `dlti` namespace to not only look up specifications via the `DataLayoutSpecOpInterface` and on `ModuleOp`s but also on any ancestor op that has an appropriate DLTI attribute.
1 parent 52220c2 commit 8f21ff9

File tree

11 files changed

+563
-0
lines changed

11 files changed

+563
-0
lines changed

mlir/include/mlir/Dialect/DLTI/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
add_subdirectory(TransformOps)
2+
13
add_mlir_dialect(DLTI dlti)
24
add_mlir_doc(DLTIAttrs DLTIDialect Dialects/ -gen-dialect-doc)
35

mlir/include/mlir/Dialect/DLTI/DLTI.h

+13
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@ namespace detail {
2222
class DataLayoutEntryAttrStorage;
2323
} // namespace detail
2424
} // namespace mlir
25+
namespace mlir {
26+
namespace dlti {
27+
/// Find the first DataLayoutSpec associated to `op`, via either the
28+
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
29+
/// the interface, on `op` and else on `op`'s ancestors in turn.
30+
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
31+
32+
/// Find the first TargetSystemSpec associated to `op`, via either the
33+
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
34+
/// the interface, on `op` and else on `op`'s ancestors in turn.
35+
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
36+
} // namespace dlti
37+
} // namespace mlir
2538

2639
#define GET_ATTRDEF_CLASSES
2740
#include "mlir/Dialect/DLTI/DLTIAttrs.h.inc"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS DLTITransformOps.td)
2+
mlir_tablegen(DLTITransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(DLTITransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRDLTITransformOpsIncGen)
5+
6+
add_mlir_doc(DLTITransformOps DLTITransformOps Dialects/ -gen-op-doc)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- DLTITransformOps.h - DLTI transform ops ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
10+
#define MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
15+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
16+
17+
namespace mlir {
18+
namespace transform {
19+
class QueryOp;
20+
} // namespace transform
21+
} // namespace mlir
22+
23+
namespace mlir {
24+
class DialectRegistry;
25+
26+
namespace dlti {
27+
void registerTransformDialectExtension(DialectRegistry &registry);
28+
} // namespace dlti
29+
} // namespace mlir
30+
31+
////===----------------------------------------------------------------------===//
32+
//// DLTI Transform Operations
33+
////===----------------------------------------------------------------------===//
34+
35+
#define GET_OP_CLASSES
36+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h.inc"
37+
38+
#endif // MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//===- DLTITransformOps.td - DLTI transform ops ------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef DLTI_TRANSFORM_OPS
10+
#define DLTI_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
15+
include "mlir/Interfaces/SideEffectInterfaces.td"
16+
include "mlir/IR/OpBase.td"
17+
18+
def QueryOp : Op<Transform_Dialect, "dlti.query", [
19+
TransformOpInterface, TransformEachOpTrait,
20+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
21+
]> {
22+
let summary = "Return attribute (as param) associated to key via DTLI";
23+
let description = [{
24+
This op queries data layout and target information associated to payload
25+
IR by way of the DLTI dialect. A lookup is performed for the given `key`
26+
at the `target` op, with the DLTI dialect determining which interfaces and
27+
attributes are consulted - first checking `target` and then its ancestors.
28+
29+
When only `key` is provided, the lookup occurs with respect to the data
30+
layout specification of DLTI. When `device` is provided, the lookup occurs
31+
with respect to DLTI's target device specifications associated to a DLTI
32+
system device specification.
33+
34+
#### Return modes
35+
36+
When succesful, the result, `associated_attr`, associates one attribute as a
37+
param for each op in `target`'s payload.
38+
39+
If the lookup fails - as DLTI specifications or entries with the right
40+
names are missing (i.e. the values of `device` and `key`) - a definite
41+
failure is returned.
42+
}];
43+
44+
let arguments = (ins TransformHandleTypeInterface:$target,
45+
OptionalAttr<StrAttr>:$device,
46+
StrAttr:$key);
47+
let results = (outs TransformParamTypeInterface:$associated_attr);
48+
let assemblyFormat =
49+
"(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
50+
"functional-type(operands, results)";
51+
52+
let extraClassDeclaration = [{
53+
::mlir::DiagnosedSilenceableFailure applyToOne(
54+
::mlir::transform::TransformRewriter &rewriter,
55+
::mlir::Operation *target,
56+
::mlir::transform::ApplyToEachResultList &results,
57+
TransformState &state);
58+
}];
59+
}
60+
61+
#endif // DLTI_TRANSFORM_OPS

mlir/include/mlir/InitAllExtensions.h

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
2626
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
2727
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
28+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
2829
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
2930
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
3031
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
@@ -69,6 +70,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
6970
// Register all transform dialect extensions.
7071
affine::registerTransformDialectExtension(registry);
7172
bufferization::registerTransformDialectExtension(registry);
73+
dlti::registerTransformDialectExtension(registry);
7274
func::registerTransformDialectExtension(registry);
7375
gpu::registerTransformDialectExtension(registry);
7476
linalg::registerTransformDialectExtension(registry);

mlir/lib/Dialect/DLTI/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(TransformOps)
12
add_mlir_dialect_library(MLIRDLTIDialect
23
DLTI.cpp
34
Traits.cpp

mlir/lib/Dialect/DLTI/DLTI.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,41 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
393393
// DLTIDialect
394394
//===----------------------------------------------------------------------===//
395395

396+
DataLayoutSpecInterface dlti::getDataLayoutSpec(Operation *op) {
397+
DataLayoutSpecInterface dlSpec = nullptr;
398+
399+
for (Operation *cur = op; cur && !dlSpec; cur = cur->getParentOp()) {
400+
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
401+
dlSpec = dataLayoutOp.getDataLayoutSpec();
402+
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
403+
dlSpec = moduleOp.getDataLayoutSpec();
404+
else
405+
for (NamedAttribute attr : cur->getAttrs())
406+
if ((dlSpec = llvm::dyn_cast<DataLayoutSpecInterface>(attr.getValue())))
407+
break;
408+
}
409+
410+
return dlSpec;
411+
}
412+
413+
TargetSystemSpecInterface dlti::getTargetSystemSpec(Operation *op) {
414+
TargetSystemSpecInterface sysSpec = nullptr;
415+
416+
for (Operation *cur = op; cur && !sysSpec; cur = cur->getParentOp()) {
417+
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
418+
sysSpec = dataLayoutOp.getTargetSystemSpec();
419+
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
420+
sysSpec = moduleOp.getTargetSystemSpec();
421+
else
422+
for (NamedAttribute attr : cur->getAttrs())
423+
if ((sysSpec =
424+
llvm::dyn_cast<TargetSystemSpecInterface>(attr.getValue())))
425+
break;
426+
}
427+
428+
return sysSpec;
429+
}
430+
396431
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
397432
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
398433
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
add_mlir_dialect_library(MLIRDLTITransformOps
2+
DLTITransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/DLTI/TransformOps
6+
7+
DEPENDS
8+
MLIRDLTITransformOpsIncGen
9+
MLIRDLTIDialect
10+
11+
LINK_LIBS PUBLIC
12+
MLIRDLTIDialect
13+
MLIRSideEffectInterfaces
14+
MLIRTransformDialect
15+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
2+
//===- DLTITransformOps.cpp - Implementation of DLTI transform ops --------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
11+
12+
#include "mlir/Dialect/DLTI/DLTI.h"
13+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
14+
#include "mlir/Dialect/Transform/Utils/Utils.h"
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
16+
17+
using namespace mlir;
18+
using namespace mlir::transform;
19+
20+
#define DEBUG_TYPE "dlti-transforms"
21+
22+
//===----------------------------------------------------------------------===//
23+
// QueryOp
24+
//===----------------------------------------------------------------------===//
25+
26+
void transform::QueryOp::getEffects(
27+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
28+
onlyReadsHandle(getTargetMutable(), effects);
29+
producesHandle(getOperation()->getOpResults(), effects);
30+
onlyReadsPayload(effects);
31+
}
32+
33+
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
34+
transform::TransformRewriter &rewriter, Operation *target,
35+
transform::ApplyToEachResultList &results, TransformState &state) {
36+
StringAttr deviceId = getDeviceAttr();
37+
StringAttr key = getKeyAttr();
38+
39+
DataLayoutEntryInterface entry;
40+
if (deviceId) {
41+
TargetSystemSpecInterface sysSpec = dlti::getTargetSystemSpec(target);
42+
if (!sysSpec)
43+
return mlir::emitDefiniteFailure(target->getLoc())
44+
<< "no target system spec associated to: " << target;
45+
46+
if (auto targetSpec = sysSpec.getDeviceSpecForDeviceID(deviceId))
47+
entry = targetSpec->getSpecForIdentifier(key);
48+
else
49+
return mlir::emitDefiniteFailure(target->getLoc())
50+
<< "no " << deviceId << " target device spec found";
51+
} else {
52+
DataLayoutSpecInterface dlSpec = dlti::getDataLayoutSpec(target);
53+
if (!dlSpec)
54+
return mlir::emitDefiniteFailure(target->getLoc())
55+
<< "no data layout spec associated to: " << target;
56+
57+
entry = dlSpec.getSpecForIdentifier(key);
58+
}
59+
60+
if (!entry)
61+
return mlir::emitDefiniteFailure(target->getLoc())
62+
<< "no DLTI entry for key: " << key;
63+
64+
results.push_back(entry.getValue());
65+
66+
return DiagnosedSilenceableFailure::success();
67+
}
68+
69+
//===----------------------------------------------------------------------===//
70+
// Transform op registration
71+
//===----------------------------------------------------------------------===//
72+
73+
namespace {
74+
class DLTITransformDialectExtension
75+
: public transform::TransformDialectExtension<
76+
DLTITransformDialectExtension> {
77+
public:
78+
using Base::Base;
79+
80+
void init() {
81+
registerTransformOps<
82+
#define GET_OP_LIST
83+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
84+
>();
85+
}
86+
};
87+
} // namespace
88+
89+
#define GET_OP_CLASSES
90+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
91+
92+
void mlir::dlti::registerTransformDialectExtension(DialectRegistry &registry) {
93+
registry.addExtensions<DLTITransformDialectExtension>();
94+
}

0 commit comments

Comments
 (0)