Skip to content

[OM] Add field_locs array attribute for ClassFieldsOp locations #8439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions include/circt/Dialect/OM/OMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,9 @@ def ClassOp : OMClassLike<"class", [
// the field location will break.
// This is required because MLIR's FusedLoc uses a "set" semantics where a
// single location is used to represent multiple fields with the same
// location. The OM implementation uses the metadata attribute on FusedLoc
// to store the original array of locations, so that the specific location
// of a field may be easily retrieved by index using the
// `getFieldLocByIndex` API.
// location. The OM implementation uses an attribute to store the original
// array of locations, so that the specific location of a field may be
// easily retrieved by index using the `getFieldLocByIndex` API.
void addNewFieldsOp(mlir::OpBuilder &builder, mlir::ArrayRef<mlir::Location>
locs, mlir::ArrayRef<mlir::Value> values);

Expand All @@ -156,8 +155,14 @@ def ClassOp : OMClassLike<"class", [

def ClassFieldsOp : OMOp<"class.fields", [Terminator, ReturnLike, Pure,
HasParent<"ClassOp">]> {
let arguments = (ins Variadic<AnyType>:$fields);
let assemblyFormat = "attr-dict ($fields^ `:` qualified(type($fields)))?";
let arguments = (ins Variadic<AnyType>:$fields,
OptionalAttr<LocationArrayAttr>:$field_locs);
let assemblyFormat = [{
attr-dict ($fields^ `:` qualified(type($fields)))?
custom<FieldLocs>($field_locs)
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,8 @@ struct ClassFieldsOpConversion : public OpConversionPattern<ClassFieldsOp> {
LogicalResult
matchAndRewrite(ClassFieldsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ClassFieldsOp>(op, adaptor.getOperands());
rewriter.replaceOpWithNewOp<ClassFieldsOp>(op, adaptor.getOperands(),
adaptor.getFieldLocsAttr());
return success();
}
};
Expand Down
78 changes: 58 additions & 20 deletions lib/Dialect/OM/OMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace mlir;
using namespace circt::om;

//===----------------------------------------------------------------------===//
// Path Printers and Parsers
// Custom Printers and Parsers
//===----------------------------------------------------------------------===//

static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path) {
Expand Down Expand Up @@ -74,6 +74,26 @@ static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path,
p << '\"';
}

static ParseResult parseFieldLocs(OpAsmParser &parser, ArrayAttr &fieldLocs) {
if (parser.parseOptionalKeyword("field_locs"))
return success();
if (parser.parseLParen() || parser.parseAttribute(fieldLocs) ||
parser.parseRParen()) {
return failure();
}
return success();
}

static void printFieldLocs(OpAsmPrinter &printer, Operation *op,
ArrayAttr fieldLocs) {
mlir::OpPrintingFlags flags;
if (!flags.shouldPrintDebugInfo() || !fieldLocs)
return;
printer << "field_locs(";
printer.printAttribute(fieldLocs);
printer << ")";
}

//===----------------------------------------------------------------------===//
// Shared definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -291,10 +311,16 @@ circt::om::ClassOp circt::om::ClassOp::buildSimpleClassOp(
Block *body = &classOp.getRegion().emplaceBlock();
auto prevLoc = odsBuilder.saveInsertionPoint();
odsBuilder.setInsertionPointToEnd(body);

mlir::SmallVector<Attribute> locAttrs(fieldNames.size(), LocationAttr(loc));

odsBuilder.create<ClassFieldsOp>(
loc, llvm::map_to_vector(fieldTypes, [&](Type type) -> Value {
return body->addArgument(type, loc);
}));
loc,
llvm::map_to_vector(
fieldTypes,
[&](Type type) -> Value { return body->addArgument(type, loc); }),
odsBuilder.getArrayAttr(locAttrs));

odsBuilder.restoreInsertionPoint(prevLoc);

return classOp;
Expand Down Expand Up @@ -409,31 +435,25 @@ void circt::om::ClassOp::addNewFieldsOp(mlir::OpBuilder &builder,
mlir::ArrayRef<Value> values) {
// Store the original locations as a metadata array so that unique locations
// are preserved as a mapping from field index to location
assert(locs.size() == values.size() && "Expected a location per value");
mlir::SmallVector<Attribute> locAttrs;
for (auto loc : locs) {
locAttrs.push_back(cast<Attribute>(LocationAttr(loc)));
}
// Also store the locations incase there's some other analysis that might
// be able to use the default FusedLoc representation.
builder.create<ClassFieldsOp>(
builder.getFusedLoc(locs, builder.getArrayAttr(locAttrs)), values);
builder.create<ClassFieldsOp>(builder.getFusedLoc(locs), values,
builder.getArrayAttr(locAttrs));
}

mlir::Location circt::om::ClassOp::getFieldLocByIndex(size_t i) {
Location loc = this->getFieldsOp()->getLoc();
if (auto locs = dyn_cast<FusedLoc>(loc)) {
// Because it's possible for a user to construct a fields op directly and
// place a FusedLoc that doersn't follow the storage format of
// addNewFieldsOp, we assert the information has been stored appropriately
ArrayAttr metadataArr = dyn_cast<ArrayAttr>(locs.getMetadata());
assert(metadataArr && "Expected fused loc to store metadata array");
assert(i < metadataArr.size() &&
"expected index to be less than array size");
LocationAttr locAttr = dyn_cast<LocationAttr>(metadataArr[i]);
assert(locAttr && "expected metadataArr entry to be location attribute");
loc = Location(locAttr);
}
return loc;
auto fieldsOp = this->getFieldsOp();
auto fieldLocs = fieldsOp.getFieldLocs();
if (!fieldLocs.has_value())
return fieldsOp.getLoc();
assert(i < fieldLocs.value().size() &&
"field index too large for location array");
return cast<LocationAttr>(fieldLocs.value()[i]);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -475,6 +495,24 @@ void circt::om::ClassExternOp::replaceFieldTypes(AttrTypeReplacer replacer) {
replaceClassLikeFieldTypes(*this, replacer);
}

//===----------------------------------------------------------------------===//
// ClassFieldsOp
//===----------------------------------------------------------------------===//
//
LogicalResult circt::om::ClassFieldsOp::verify() {
auto fieldLocs = this->getFieldLocs();
if (fieldLocs.has_value()) {
auto fieldLocsVal = fieldLocs.value();
if (fieldLocsVal.size() != this->getFields().size()) {
auto error = this->emitOpError("size of field_locs (")
<< fieldLocsVal.size()
<< ") does not match number of fields ("
<< this->getFields().size() << ")";
}
}
return success();
}

//===----------------------------------------------------------------------===//
// ObjectOp
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 16 additions & 0 deletions test/Dialect/OM/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,19 @@ om.class @A(%arg: i1) -> (a: i2) {
// expected-note @+1 {{see terminator:}}
om.class.fields %arg : i1
}


// -----

om.class @A(%arg: i1) -> (a: i1) {
// expected-error @+1 {{expected ')'}}
om.class.fields %arg : i1 field_locs([loc("loc0")], [loc("loc1")])
}


// -----

om.class @A(%arg: i1) -> (a: i1) {
// expected-error @+1 {{'om.class.fields' op size of field_locs (2) does not match number of fields (1)}}
om.class.fields %arg : i1 field_locs([loc("loc0"), loc("loc1")])
}
4 changes: 2 additions & 2 deletions test/Dialect/OM/round-trip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ om.class @Thingy(%blue_1: i8, %blue_2: i32) -> (widget: !om.class.type<@Widget>,
// CHECK: %[[widget_field:.+]] = om.object.field %[[widget]], [@blue_1] : (!om.class.type<@Widget>) -> i8
%6 = om.object.field %2, [@blue_1] : (!om.class.type<@Widget>) -> i8

// CHECK: om.class.fields {test = "fieldsAttr"} %2, %5, %blue_1, %6 : !om.class.type<@Widget>, !om.class.type<@Gadget>, i8, i8 loc("test")
om.class.fields {test = "fieldsAttr"} %2, %5, %blue_1, %6 : !om.class.type<@Widget>, !om.class.type<@Gadget>, i8, i8 loc("test")
// CHECK: om.class.fields {test = "fieldsAttr"} %2, %5, %blue_1, %6 : !om.class.type<@Widget>, !om.class.type<@Gadget>, i8, i8 field_locs([loc("loc0"), loc("loc1"), loc("loc2"), loc("loc3")]) loc("test")
om.class.fields {test = "fieldsAttr"} %2, %5, %blue_1, %6 : !om.class.type<@Widget>, !om.class.type<@Gadget>, i8, i8 field_locs([loc("loc0"), loc("loc1"), loc("loc2"), loc("loc3")]) loc("test")
}

// CHECK-LABEL: om.class @Widget
Expand Down
23 changes: 14 additions & 9 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ TEST(EvaluatorTests, GetFieldInvalidName) {
auto cls = builder.create<ClassOp>("MyClass");
auto &body = cls.getBody().emplaceBlock();
builder.setInsertionPointToStart(&body);
builder.create<ClassFieldsOp>(loc, llvm::ArrayRef<mlir::Value>());
builder.create<ClassFieldsOp>(loc, llvm::ArrayRef<mlir::Value>(),
ArrayAttr{});

Evaluator evaluator(mod);

Expand Down Expand Up @@ -253,7 +254,8 @@ TEST(EvaluatorTests, InstantiateObjectWithConstantField) {
builder.setInsertionPointToStart(&body);
auto constant = builder.create<ConstantOp>(
circt::om::IntegerAttr::get(&context, constantType));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({constant}));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({constant}),
ArrayAttr{});

Evaluator evaluator(mod);

Expand Down Expand Up @@ -304,7 +306,7 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObject) {
body.addArgument(circt::om::OMIntegerType::get(&context), cls.getLoc());
builder.setInsertionPointToStart(&body);
auto object = builder.create<ObjectOp>(innerCls, body.getArguments());
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({object}));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({object}), ArrayAttr{});

Evaluator evaluator(mod);

Expand Down Expand Up @@ -368,7 +370,7 @@ TEST(EvaluatorTests, InstantiateObjectWithFieldAccess) {
builder.create<ObjectFieldOp>(builder.getI32Type(), object,
builder.getArrayAttr(FlatSymbolRefAttr::get(
builder.getStringAttr("field"))));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({field}));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({field}), ArrayAttr{});

Evaluator evaluator(mod);

Expand Down Expand Up @@ -407,7 +409,8 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObjectMemoized) {
auto innerCls = builder.create<ClassOp>("MyInnerClass");
auto &innerBody = innerCls.getBody().emplaceBlock();
builder.setInsertionPointToStart(&innerBody);
builder.create<ClassFieldsOp>(loc, llvm::ArrayRef<mlir::Value>());
builder.create<ClassFieldsOp>(loc, llvm::ArrayRef<mlir::Value>(),
ArrayAttr{});

builder.setInsertionPointToStart(&mod.getBodyRegion().front());
auto innerType = TypeAttr::get(ClassType::get(
Expand All @@ -422,7 +425,8 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObjectMemoized) {
auto &body = cls.getBody().emplaceBlock();
builder.setInsertionPointToStart(&body);
auto object = builder.create<ObjectOp>(innerCls, body.getArguments());
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({object, object}));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({object, object}),
ArrayAttr{});

Evaluator evaluator(mod);

Expand Down Expand Up @@ -476,7 +480,8 @@ TEST(EvaluatorTests, AnyCastObject) {
auto innerCls = builder.create<ClassOp>("MyInnerClass");
auto &innerBody = innerCls.getBody().emplaceBlock();
builder.setInsertionPointToStart(&innerBody);
builder.create<ClassFieldsOp>(loc, llvm::ArrayRef<mlir::Value>());
builder.create<ClassFieldsOp>(loc, llvm::ArrayRef<mlir::Value>(),
ArrayAttr{});

builder.setInsertionPointToStart(&mod.getBodyRegion().front());
auto innerType = TypeAttr::get(ClassType::get(
Expand All @@ -491,7 +496,7 @@ TEST(EvaluatorTests, AnyCastObject) {
builder.setInsertionPointToStart(&body);
auto object = builder.create<ObjectOp>(innerCls, body.getArguments());
auto cast = builder.create<AnyCastOp>(object);
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({cast}));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({cast}), ArrayAttr{});

Evaluator evaluator(mod);

Expand Down Expand Up @@ -545,7 +550,7 @@ TEST(EvaluatorTests, AnyCastParam) {
auto cast = builder.create<AnyCastOp>(body.getArgument(0));
SmallVector<Value> objectParams = {cast};
auto object = builder.create<ObjectOp>(innerCls, objectParams);
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({object}));
builder.create<ClassFieldsOp>(loc, SmallVector<Value>({object}), ArrayAttr{});

Evaluator evaluator(mod);

Expand Down
Loading